How to use the sockeye.data_io.ParallelDataSet class in sockeye

To help you get started, we’ve selected a few sockeye examples based on popular ways it is used in public projects.

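Before the project examples, here is a minimal sketch of the workflow they exercise: define parallel buckets and per-bucket batch sizes, wrap one NDArray per bucket for source, target, and label in a ParallelDataSet, fill each bucket up to a multiple of its batch size, and shuffle it with bucket-wise permutations. The _random_bucketed_arrays helper is a made-up stand-in for the tests' _get_random_bucketed_data utility, and array shapes and the fill_up signature differ between sockeye versions (some expect a trailing source-factor dimension and an explicit fill-up policy), so treat this as an illustrative sketch rather than a drop-in snippet.

import mxnet as mx
from sockeye import data_io


def _random_bucketed_arrays(buckets, num_samples=4):
    # Made-up stand-in for the tests' _get_random_bucketed_data helper:
    # one (num_samples, seq_len) NDArray per bucket for source, target, and label.
    source, target, label = [], [], []
    for source_len, target_len in buckets:
        source.append(mx.nd.random.uniform(shape=(num_samples, source_len)))
        target.append(mx.nd.random.uniform(shape=(num_samples, target_len)))
        label.append(mx.nd.random.uniform(shape=(num_samples, target_len)))
    return source, target, label


batch_size = 5
buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                       batch_size,
                                                       batch_by_words=False,
                                                       batch_num_devices=1,
                                                       data_target_average_len=[None] * len(buckets))

# Wrap the per-bucket arrays and pad each bucket to a multiple of its batch size.
dataset = data_io.ParallelDataSet(*_random_bucketed_arrays(buckets)).fill_up(bucket_batch_sizes)

# Shuffle within buckets; applying the inverse permutations restores the original order.
permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts())
shuffled = dataset.permute(permutations)
restored = shuffled.permute(inverse_permutations)

The examples below exercise the same pieces against the project's own test utilities: fill_up and permute directly, and ParallelSampleIter / ShardedParallelSampleIter for iterating over batches from one or more saved shards.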

github awslabs / sockeye / test / unit / test_data_io.py (view on GitHub)
def test_parallel_data_set_permute():
    batch_size = 5
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))
    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5)).fill_up(
        bucket_batch_sizes)

    permutations, inverse_permutations = data_io.get_permutations(dataset.get_bucket_counts())

    assert len(permutations) == len(inverse_permutations) == len(dataset)
    dataset_restored = dataset.permute(permutations).permute(inverse_permutations)
    assert len(dataset) == len(dataset_restored)
    for buck_idx in range(len(dataset)):
        num_samples = dataset.source[buck_idx].shape[0]
        if num_samples:
            assert (dataset.source[buck_idx] == dataset_restored.source[buck_idx]).asnumpy().all()
            assert (dataset.target[buck_idx] == dataset_restored.target[buck_idx]).asnumpy().all()
            assert (dataset.label[buck_idx] == dataset_restored.label[buck_idx]).asnumpy().all()
        else:
            assert not dataset_restored.source[buck_idx]
            assert not dataset_restored.target[buck_idx]
            assert not dataset_restored.label[buck_idx]
github awslabs / sockeye / test / unit / test_data_io.py (view on GitHub)
def test_sharded_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                  bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                  bucket_counts=bucket_counts))

    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes, 'replicate')

        # Test 1
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
github awslabs / sockeye / test / unit / test_data_io.py (view on GitHub)
    num_shards = 2
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches_per_shard = num_batches_per_bucket * len(buckets)
    num_batches = num_shards * num_batches_per_shard
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset1 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                  bucket_counts=bucket_counts))
    dataset2 = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                  bucket_counts=bucket_counts))
    with TemporaryDirectory() as work_dir:
        shard1_fname = os.path.join(work_dir, 'shard1')
        shard2_fname = os.path.join(work_dir, 'shard2')
        dataset1.save(shard1_fname)
        dataset2.save(shard2_fname)
        shard_fnames = [shard1_fname, shard2_fname]

        it = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                               'replicate')

        num_batches_seen = 0
        while it.iter_next():
            it.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches
github awslabs / sockeye / test / unit / test_data_io.py (view on GitHub)
def test_sharded_and_parallel_iter_same_num_batches():
    """ Tests that a sharded data iterator with just a single shard produces as many shards as an iterator directly
    using the same dataset. """
    batch_size = 2
    num_batches_per_bucket = 10
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    bucket_counts = [batch_size * num_batches_per_bucket for _ in buckets]
    num_batches = num_batches_per_bucket * len(buckets)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                 bucket_counts=bucket_counts))

    with TemporaryDirectory() as work_dir:
        shard_fname = os.path.join(work_dir, 'shard1')
        dataset.save(shard_fname)
        shard_fnames = [shard_fname]

        it_sharded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                                       'replicate')

        it_parallel = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches
github awslabs / sockeye / test / unit / image_captioning / test_data_io.py (view on GitHub)
def test_raw_list_text_dset_loader(source_list, target_sentences, num_samples_per_bucket,
                                   expected_source_0, expected_target_0, expected_label_0):
    # Test Init object
    buckets = sockeye.data_io.define_parallel_buckets(4, 4, 1, 1.0)
    dset_loader = data_io.RawListTextDatasetLoader(buckets=buckets,
                                       eos_id=10, pad_id=C.PAD_ID)

    assert isinstance(dset_loader, data_io.RawListTextDatasetLoader)
    assert len(dset_loader.buckets)==3

    # Test Load data
    pop_dset_loader = dset_loader.load(source_list, target_sentences, num_samples_per_bucket)

    assert isinstance(pop_dset_loader, sockeye.data_io.ParallelDataSet)
    assert len(pop_dset_loader.source)==3
    assert len(pop_dset_loader.target)==3
    assert len(pop_dset_loader.label)==3
    np.testing.assert_equal(pop_dset_loader.source[0], expected_source_0)
    np.testing.assert_almost_equal(pop_dset_loader.target[0].asnumpy(), expected_target_0)
    np.testing.assert_almost_equal(pop_dset_loader.label[0].asnumpy(), expected_label_0)
github Cartus / DCGCN / sockeye / data_io.py (view on GitHub)
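                # Excerpt, most likely from the DCGCN fork's ParallelDataSet.permute: each bucket's source,
                # target, and label arrays, plus the fork's graph and position inputs, are reordered by the
                # given per-bucket permutation before a new ParallelDataSet is built.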
                if isinstance(self.source[buck_idx], np.ndarray):
                    source.append(self.source[buck_idx].take(np.int64(permutation.asnumpy())))
                else:
                    source.append(self.source[buck_idx].take(permutation))
                target.append(self.target[buck_idx].take(permutation))
                label.append(self.label[buck_idx].take(permutation))
                graph.append(self.src_graphs[buck_idx].take(permutation))
                position.append(self.src_positions[buck_idx].take(permutation))
            else:
                source.append(self.source[buck_idx])
                target.append(self.target[buck_idx])
                label.append(self.label[buck_idx])
                graph.append(self.src_graphs[buck_idx])
                position.append(self.src_positions[buck_idx])

        return ParallelDataSet(source, target, label, graph, position)
github awslabs / sockeye / sockeye / data_io.py (view on GitHub)
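            # Excerpt, most likely from RawParallelDatasetLoader.load: per-bucket buffers for source, target,
            # and the EOS-shifted label sequences are filled, converted to mx.nd arrays of the configured
            # dtype, and wrapped in a ParallelDataSet.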
            # we can try again to compute the label sequence on the fly in next().
            data_label[buck_index][sample_index, :target_len] = target[1:] + [self.eos_id]

            bucket_sample_index[buck_index] += 1

        for i in range(len(data_source)):
            data_source[i] = mx.nd.array(data_source[i], dtype=self.dtype)
            data_target[i] = mx.nd.array(data_target[i], dtype=self.dtype)
            data_label[i] = mx.nd.array(data_label[i], dtype=self.dtype)

        if num_tokens_source > 0 and num_tokens_target > 0:
            logger.info("Created bucketed parallel data set. Introduced padding: source=%.1f%% target=%.1f%%)",
                        num_pad_source / num_tokens_source * 100,
                        num_pad_target / num_tokens_target * 100)

        return ParallelDataSet(data_source, data_target, data_label)
github Cartus / DCGCN / sockeye / data_io.py (view on GitHub)
def _load_shard(self):
        shard_fname = self.shards_fnames[self.shard_index]
        logger.info("Loading shard %s.", shard_fname)
        dataset = ParallelDataSet.load(self.shards_fnames[self.shard_index]).fill_up(self.bucket_batch_sizes,
                                                                                     self.fill_up,
                                                                                     seed=self.shard_index)
        self.shard_iter = ParallelSampleIter(data=dataset,
                                             buckets=self.buckets,
                                             batch_size=self.batch_size,
                                             bucket_batch_sizes=self.bucket_batch_sizes,
                                             source_data_name=self.source_data_name,
                                             target_data_name=self.target_data_name,
                                             num_factors=self.num_factors)
github Cartus / DCGCN / sockeye / data_io.py (view on GitHub)
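            # Excerpt from the DCGCN fork's dataset loader (context assumed): besides source, target, and
            # label, per-bucket graph adjacency matrices and position features are built and passed to the
            # fork's five-field ParallelDataSet.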
            data_source[i] = mx.nd.array(data_source[i], dtype=self.dtype)
            data_target[i] = mx.nd.array(data_target[i], dtype=self.dtype)
            data_label[i] = mx.nd.array(data_label[i], dtype=self.dtype)

            data_src_graphs[i], global_index_list = self._convert_to_adj_matrix(self.buckets[i][0], data_src_graphs[i])
            data_src_positions[i] = self._get_graph_positions(self.buckets[i][0], data_src_graphs[i], global_index_list)

            data_src_graphs[i] = mx.nd.array(data_src_graphs[i], dtype=self.dtype)
            data_src_positions[i] = mx.nd.array(data_src_positions[i], dtype=self.dtype)

        if num_tokens_source > 0 and num_tokens_target > 0:
            logger.info("Created bucketed parallel data set. Introduced padding: source=%.1f%% target=%.1f%%)",
                        num_pad_source / num_tokens_source * 100,
                        num_pad_target / num_tokens_target * 100)

        return ParallelDataSet(data_source, data_target, data_label, data_src_graphs, data_src_positions)