import os
from tempfile import TemporaryDirectory

from sockeye import data_io

# _get_random_bucketed_data and _data_batches_equal are helpers defined elsewhere
# in this test module.


def test_parallel_sample_iter():
    batch_size = 2
    buckets = data_io.define_parallel_buckets(100, 100, 10, 1.0)
    # The first bucket is going to be empty:
    bucket_counts = [0] + [None] * (len(buckets) - 1)
    bucket_batch_sizes = data_io.define_bucket_batch_sizes(buckets,
                                                           batch_size,
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                 bucket_counts=bucket_counts))
    it = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)

    with TemporaryDirectory() as work_dir:
        # Test 1: save state after drawing some batches and check that a reloaded
        # iterator yields a batch equal to the reference batch.
        it.next()
        expected_batch = it.next()

        fname = os.path.join(work_dir, "saved_iter")
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 2: same check right after a reset.
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        # Test 3: repeat once more, then drain both iterators in lockstep and
        # verify they run out of batches at the same point.
        it.reset()
        expected_batch = it.next()
        it.save_state(fname)

        it_loaded = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
        it_loaded.reset()
        it_loaded.load_state(fname)
        loaded_batch = it_loaded.next()
        assert _data_batches_equal(expected_batch, loaded_batch)

        while it.iter_next():
            it.next()
            it_loaded.next()
        assert not it_loaded.iter_next()
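
# A minimal standalone sketch (not part of the test above) of the checkpoint/resume
# pattern that test exercises: persist the iterator position with save_state() and
# restore it into a freshly built iterator with load_state(). The function name and
# the state_path argument are illustrative assumptions.
def resume_iterator_example(dataset, buckets, batch_size, bucket_batch_sizes, state_path):
    it = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
    it.next()                    # consume one or more batches ...
    it.save_state(state_path)    # ... then checkpoint the iterator state to disk

    # Later, e.g. after a process restart: rebuild an iterator over the same
    # dataset, reset it, and load the saved state before continuing.
    it_resumed = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)
    it_resumed.reset()
    it_resumed.load_state(state_path)

    remaining_batches = []
    while it_resumed.iter_next():
        remaining_batches.append(it_resumed.next())  # continues from the restored position
    return remaining_batches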
# Excerpt from a second test (its opening lines, including the start of the
# define_bucket_batch_sizes call and the definition of num_batches, are not shown):
# it compares a ShardedParallelSampleIter reading a single saved shard against a
# ParallelSampleIter over the same in-memory dataset.
                                                           batch_by_words=False,
                                                           batch_num_devices=1,
                                                           data_target_average_len=[None] * len(buckets))

    dataset = data_io.ParallelDataSet(*_get_random_bucketed_data(buckets, min_count=0, max_count=5,
                                                                 bucket_counts=bucket_counts))

    with TemporaryDirectory() as work_dir:
        shard_fname = os.path.join(work_dir, 'shard1')
        dataset.save(shard_fname)
        shard_fnames = [shard_fname]

        it_sharded = data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size, bucket_batch_sizes,
                                                       'replicate')

        it_parallel = data_io.ParallelSampleIter(dataset, buckets, batch_size, bucket_batch_sizes)

        # Both iterators must yield the same number of batches ...
        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
            it_parallel.next()
            it_sharded.next()
            num_batches_seen += 1
        assert num_batches_seen == num_batches

        # ... and must do so again after a reset.
        print("Resetting...")
        it_sharded.reset()
        it_parallel.reset()

        num_batches_seen = 0
        while it_parallel.iter_next():
            assert it_sharded.iter_next()
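
# A minimal sketch (an assumption, not part of the test above) of building a sharded
# iterator from datasets written to disk: each ParallelDataSet is saved with save(),
# and the list of shard file names is handed to ShardedParallelSampleIter together
# with the 'replicate' fill-up policy used in the excerpt above. The helper name is
# illustrative.
def sharded_iter_example(datasets, work_dir, buckets, batch_size, bucket_batch_sizes):
    shard_fnames = []
    for i, dataset in enumerate(datasets):
        shard_fname = os.path.join(work_dir, 'shard%d' % i)
        dataset.save(shard_fname)           # one file per shard
        shard_fnames.append(shard_fname)
    return data_io.ShardedParallelSampleIter(shard_fnames, buckets, batch_size,
                                             bucket_batch_sizes, 'replicate')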
# Excerpt: _load_shard() from a sharded iterator implementation. It reloads one shard
# from disk, fills up its buckets (here with an explicit fill-up policy and a
# per-shard seed), and wraps it in a ParallelSampleIter.
    def _load_shard(self):
        shard_fname = self.shards_fnames[self.shard_index]
        logger.info("Loading shard %s.", shard_fname)
        dataset = ParallelDataSet.load(shard_fname).fill_up(self.bucket_batch_sizes,
                                                            self.fill_up,
                                                            seed=self.shard_index)
        self.shard_iter = ParallelSampleIter(data=dataset,
                                             buckets=self.buckets,
                                             batch_size=self.batch_size,
                                             bucket_batch_sizes=self.bucket_batch_sizes,
                                             source_data_name=self.source_data_name,
                                             target_data_name=self.target_data_name,
                                             num_factors=self.num_factors)
# Excerpt (start truncated) from a validation-data helper in a graph-augmented
# variant: it computes data statistics, logs them, loads the validation sentences
# and graphs into bucketed data, and returns a ParallelSampleIter.
                                                               edge_vocab)

    validation_data_statistics = get_data_statistics(validation_sources_sentences,
                                                     validation_target_sentences,
                                                     buckets,
                                                     validation_length_statistics.length_ratio_mean,
                                                     validation_length_statistics.length_ratio_std,
                                                     source_vocabs, target_vocab)

    validation_data_statistics.log(bucket_batch_sizes)

    validation_data = data_loader.load(validation_sources_sentences, validation_target_sentences, validation_graphs,
                                       validation_data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes,
                                                                                                fill_up)

    return ParallelSampleIter(data=validation_data,
                              buckets=buckets,
                              batch_size=batch_size,
                              bucket_batch_sizes=bucket_batch_sizes,
                              num_factors=len(validation_sources))
# Excerpt: a second _load_shard() variant. This one lets fill_up() use its default
# policy and forwards a permute flag to the per-shard ParallelSampleIter.
    def _load_shard(self):
        shard_fname = self.shards_fnames[self.shard_index]
        logger.info("Loading shard %s.", shard_fname)
        dataset = ParallelDataSet.load(shard_fname).fill_up(self.bucket_batch_sizes,
                                                            seed=self.shard_index)
        self.shard_iter = ParallelSampleIter(data=dataset,
                                             buckets=self.buckets,
                                             batch_size=self.batch_size,
                                             bucket_batch_sizes=self.bucket_batch_sizes,
                                             source_data_name=self.source_data_name,
                                             target_data_name=self.target_data_name,
                                             num_factors=self.num_factors,
                                             permute=self.permute)
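
# A standalone sketch of what _load_shard() does, assuming a shard previously written
# with ParallelDataSet.save(): reload the dataset, pad its buckets against the bucket
# batch sizes via fill_up(), and wrap it in a ParallelSampleIter. Keyword names mirror
# the method above; the function name and seed default are illustrative.
def load_shard_example(shard_fname, buckets, batch_size, bucket_batch_sizes, seed=0):
    dataset = ParallelDataSet.load(shard_fname).fill_up(bucket_batch_sizes, seed=seed)
    return ParallelSampleIter(data=dataset,
                              buckets=buckets,
                              batch_size=batch_size,
                              bucket_batch_sizes=bucket_batch_sizes)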
# Excerpt (start truncated) from another validation-data helper, this time without
# graph inputs: the same statistics/log/load/fill_up pipeline ending in a
# ParallelSampleIter.
                                                                          validation_target,
                                                                          source_vocabs, target_vocab)

    validation_data_statistics = get_data_statistics(validation_sources_sentences,
                                                     validation_target_sentences,
                                                     buckets,
                                                     validation_length_statistics.length_ratio_mean,
                                                     validation_length_statistics.length_ratio_std,
                                                     source_vocabs, target_vocab)

    validation_data_statistics.log(bucket_batch_sizes)

    validation_data = data_loader.load(validation_sources_sentences, validation_target_sentences,
                                       validation_data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes)

    return ParallelSampleIter(data=validation_data,
                              buckets=buckets,
                              batch_size=batch_size,
                              bucket_batch_sizes=bucket_batch_sizes,
                              num_factors=len(validation_sources))
# Excerpt (start truncated) from a training-data setup routine: it records data
# provenance (DataInfo) and configuration (DataConfig), builds the training
# ParallelSampleIter with permute=True, and builds the matching validation iterator.
                                       data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes)

    data_info = DataInfo(sources=sources,
                         target=target,
                         source_vocabs=source_vocab_paths,
                         target_vocab=target_vocab_path,
                         shared_vocab=shared_vocab,
                         num_shards=1)

    config_data = DataConfig(data_statistics=data_statistics,
                             max_seq_len_source=max_seq_len_source,
                             max_seq_len_target=max_seq_len_target,
                             num_source_factors=len(sources),
                             source_with_eos=True)

    train_iter = ParallelSampleIter(data=training_data,
                                    buckets=buckets,
                                    batch_size=batch_size,
                                    bucket_batch_sizes=bucket_batch_sizes,
                                    num_factors=len(sources),
                                    permute=True)

    validation_iter = get_validation_data_iter(data_loader=data_loader,
                                               validation_sources=validation_sources,
                                               validation_target=validation_target,
                                               buckets=buckets,
                                               bucket_batch_sizes=bucket_batch_sizes,
                                               source_vocabs=source_vocabs,
                                               target_vocab=target_vocab,
                                               max_seq_len_source=max_seq_len_source,
                                               max_seq_len_target=max_seq_len_target,
                                               batch_size=batch_size)
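
# A minimal sketch (an assumption about typical usage, not taken from the excerpts)
# of how the two iterators built above can be driven for one epoch: drain train_iter
# batch by batch, sweep validation_iter the same way, then reset both before the next
# epoch. process_batch and evaluate_batch are hypothetical callbacks.
def run_epoch(train_iter, validation_iter, process_batch, evaluate_batch):
    while train_iter.iter_next():
        process_batch(train_iter.next())       # e.g. one training step per batch
    while validation_iter.iter_next():
        evaluate_batch(validation_iter.next()) # e.g. accumulate validation metrics
    train_iter.reset()
    validation_iter.reset()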