# data loader for training
loader_train = mx.gluon.data.DataLoader(dataset=data_train,
                                        batch_sampler=batch_sampler, batchify_fn=batchify_fn)
# data dev. For MNLI, more than one dev set is available
dev_tsv = task.dataset_dev()
dev_tsv_list = dev_tsv if isinstance(dev_tsv, list) else [dev_tsv]
loader_dev_list = []
for segment, data in dev_tsv_list:
data_dev = mx.gluon.data.SimpleDataset(list(map(trans, data)))
loader_dev = mx.gluon.data.DataLoader(data_dev, batch_size=dev_batch_size, num_workers=4,
shuffle=False, batchify_fn=batchify_fn)
loader_dev_list.append((segment, loader_dev))
# batchify for data test
test_batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Pad(axis=0, pad_val=pad_val, round_to=args.round_to),
nlp.data.batchify.Pad(axis=0, pad_val=0, round_to=args.round_to),
nlp.data.batchify.Stack())
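# Hedged illustration, not part of the original script: how a Tuple of
# Pad/Pad/Stack batchify functions (same shape as test_batchify_fn above)
# collates variable-length samples. `toy_batchify` and `toy_samples` are
# made-up names; the `nlp` alias is assumed to be gluonnlp as elsewhere here.
toy_batchify = nlp.data.batchify.Tuple(
    nlp.data.batchify.Pad(axis=0, pad_val=0),  # token ids, padded to the longest sample
    nlp.data.batchify.Pad(axis=0, pad_val=0),  # segment ids, padded the same way
    nlp.data.batchify.Stack())                 # valid lengths, stacked unchanged
toy_samples = [([2, 7, 9, 4], [0, 0, 0, 0], 4),
               ([2, 5, 4], [0, 0, 0], 3)]
toy_ids, toy_segments, toy_lengths = toy_batchify(toy_samples)
# toy_ids has shape (2, 4); the second row is right-padded with pad_val=0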
# transform for data test
test_trans = partial(convert_examples_to_features, tokenizer=tokenizer, truncate_length=max_len,
cls_token=vocab.cls_token if not use_roberta else vocab.bos_token,
sep_token=vocab.sep_token if not use_roberta else vocab.eos_token,
class_labels=None, is_test=True, vocab=vocab)
# data test. For MNLI, more than one test set is available
test_tsv = task.dataset_test()
test_tsv_list = test_tsv if isinstance(test_tsv, list) else [test_tsv]
loader_test_list = []
for segment, data in test_tsv_list:
data_test = mx.gluon.data.SimpleDataset(list(map(test_trans, data)))
loader_test = mx.gluon.data.DataLoader(data_test, batch_size=dev_batch_size, num_workers=4,
shuffle=False, batchify_fn=test_batchify_fn)
loader_test_list.append((segment, loader_test))
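# Hedged usage sketch, not from the original script: the per-segment test
# loaders built above yield batches in the order of the three test_batchify_fn
# fields (token ids, segment ids, valid length); the loop body is left open.
for segment, loader_test in loader_test_list:
    for token_ids, segment_ids, valid_length in loader_test:
        pass  # run the model on the batch and collect predictions for this segment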
get_model_params = {
    'name': args.model,
    'dataset_name': args.dataset,
    'pretrained': get_pretrained,
    'ctx': ctx,
    'use_decoder': False,
    'dropout': args.dropout,
    'attention_dropout': args.attention_dropout
}
# model, vocabulary and tokenizer
xlnet_base, vocab, tokenizer = model.get_model(**get_model_params)
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Stack('int32'), # example_id
nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], dtype='int32',
round_to=args.round_to), # input_ids
nlp.data.batchify.Pad(axis=0, pad_val=3, dtype='int32', round_to=args.round_to), # segment_ids
nlp.data.batchify.Stack('float32'), # valid_length
nlp.data.batchify.Pad(axis=0, pad_val=1, round_to=args.round_to), # p_mask
nlp.data.batchify.Stack('float32'), # start_position
nlp.data.batchify.Stack('float32'), # end_position
nlp.data.batchify.Stack('float32')) # is_impossible
if pretrained_xlnet_parameters:
# only load XLNetModel parameters
nlp.utils.load_parameters(xlnet_base, pretrained_xlnet_parameters, ctx=ctx, ignore_extra=True,
cast_dtype=True)
units = xlnet_base._net._units
net = XLNetForQA(xlnet_base=xlnet_base, start_top_n=args.start_top_n, end_top_n=args.end_top_n,
units=units)
net_eval = XLNetForQA(xlnet_base=xlnet_base, start_top_n=args.start_top_n,
end_top_n=args.end_top_n, units=units, is_eval=True,
params=net.collect_params())
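# Note: net_eval above reuses params=net.collect_params(), so the training and
# evaluation networks share the same underlying weights; only the forward
# behaviour differs because net_eval is built with is_eval=True.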
def __call__(self, dataset):
"""create data loader based on the dataset chunk"""
if isinstance(dataset, nlp.data.NumpyDataset):
lengths = dataset.get_field('valid_lengths')
elif isinstance(dataset, BERTPretrainDataset):
lengths = dataset.transform(lambda input_ids, segment_ids, masked_lm_positions, \
masked_lm_ids, masked_lm_weights, \
next_sentence_labels, valid_lengths: \
valid_lengths, lazy=False)
else:
raise ValueError('unexpected dataset type: %s'%str(dataset))
# A batch includes: input_id, masked_id, masked_position, masked_weight,
# next_sentence_label, segment_id, valid_length
batchify_fn = Tuple(Pad(), Pad(), Pad(), Pad(), Stack(), Pad(), Stack())
if self._use_avg_len:
# sharded data loader
sampler = nlp.data.FixedBucketSampler(lengths=lengths,
# batch_size per shard
batch_size=self._batch_size,
num_buckets=self._num_buckets,
shuffle=self._shuffle,
use_average_length=True,
num_shards=self._num_ctxes)
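            # With num_shards equal to the number of contexts, the sampler
            # emits each batch pre-split into one shard per device, which is
            # what the ShardedDataLoader below expects.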
dataloader = nlp.data.ShardedDataLoader(dataset,
batch_sampler=sampler,
batchify_fn=batchify_fn,
num_workers=self._num_ctxes)
else:
    sampler = nlp.data.FixedBucketSampler(lengths,
                                          batch_size=self._batch_size * self._num_ctxes,
                                          num_buckets=self._num_buckets,
                                          shuffle=self._shuffle)
num_files = len(nlp.utils.glob(data))
logging.info('%d files are found.', num_files)
assert num_files >= num_parts, \
'The number of text files must be no less than the number of ' \
'workers/partitions (%d). Only %d files at %s are found.'%(num_parts, num_files, data)
dataset_params = {'tokenizer': tokenizer, 'max_seq_length': max_seq_length,
'short_seq_prob': short_seq_prob, 'masked_lm_prob': masked_lm_prob,
'max_predictions_per_seq': max_predictions_per_seq, 'vocab': vocab,
'whole_word_mask': whole_word_mask}
sampler_params = {'batch_size': batch_size, 'shuffle': shuffle,
'num_ctxes': num_ctxes, 'num_buckets': num_buckets}
dataset_fn = prepare_pretrain_text_dataset
sampler_fn = prepare_pretrain_bucket_sampler
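# The two *_fn callables and their parameter dicts are passed to
# nlp.data.DatasetLoader below, which builds a dataset and a batch sampler
# on the fly for every text file selected by the file sampler, rather than
# materializing all datasets up front.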
pad_val = vocab[vocab.padding_token]
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Pad(pad_val=pad_val, round_to=8), # input_id
nlp.data.batchify.Pad(pad_val=pad_val), # masked_id
nlp.data.batchify.Pad(pad_val=0), # masked_position
nlp.data.batchify.Pad(pad_val=0), # masked_weight
nlp.data.batchify.Stack(), # next_sentence_label
nlp.data.batchify.Pad(pad_val=0, round_to=8), # segment_id
nlp.data.batchify.Stack())
split_sampler = nlp.data.SplitSampler(num_files, num_parts=num_parts,
part_index=part_idx, repeat=repeat)
dataloader = nlp.data.DatasetLoader(data,
file_sampler=split_sampler,
dataset_fn=dataset_fn,
batch_sampler_fn=sampler_fn,
dataset_params=dataset_params,
batch_sampler_params=sampler_params,
batchify_fn=batchify_fn,
                                    num_dataset_workers=num_dataset_workers)
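# Hedged usage sketch, not from the original script: every batch produced by
# the loader follows the Tuple order above (input_id, masked_id,
# masked_position, masked_weight, next_sentence_label, segment_id, valid_length).
for data_batch in dataloader:
    pass  # unpack the seven fields and feed them to the pretraining step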
def make_dataloader(data_train, data_val, data_test, args,
use_average_length=False, num_shards=0, num_workers=8):
"""Create data loaders for training/validation/test."""
data_train_lengths = get_data_lengths(data_train)
data_val_lengths = get_data_lengths(data_val)
data_test_lengths = get_data_lengths(data_test)
train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
btf.Stack())
target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
if args.bucket_scheme == 'constant':
bucket_scheme = nlp.data.ConstWidthBucket()
elif args.bucket_scheme == 'linear':
bucket_scheme = nlp.data.LinearWidthBucket()
elif args.bucket_scheme == 'exp':
bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
else:
raise NotImplementedError
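    # A note on the schemes above: ConstWidthBucket uses equal-width length
    # buckets, LinearWidthBucket lets the bucket width grow linearly with
    # length, and ExpWidthBucket grows it geometrically (factor 1.2 here), so
    # short sequences get fine-grained buckets while a few wide buckets cover
    # the long tail.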
    train_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_train_lengths,
                                                      batch_size=args.batch_size,
                                                      num_buckets=args.num_buckets,
                                                      ratio=args.bucket_ratio,
                                                      shuffle=True,
                                                      use_average_length=use_average_length,
                                                      num_shards=num_shards,
                                                      bucket_scheme=bucket_scheme)

# transform for data train
trans = BERTDatasetTransform(tokenizer, max_len,
                             vocab=vocab,
                             class_labels=task.class_labels,
                             label_alias=task.label_alias,
                             pad=pad, pair=task.is_pair,
                             has_label=True)
# data train
# task.dataset_train returns (segment_name, dataset)
train_tsv = task.dataset_train()[1]
data_train = mx.gluon.data.SimpleDataset(pool.map(trans, train_tsv))
data_train_len = data_train.transform(
lambda input_id, length, segment_id, label_id: length, lazy=False)
# bucket sampler for training
pad_val = vocab[vocab.padding_token]
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Pad(axis=0, pad_val=pad_val), # input
nlp.data.batchify.Stack(), # length
nlp.data.batchify.Pad(axis=0, pad_val=0), # segment
nlp.data.batchify.Stack(label_dtype)) # label
batch_sampler = nlp.data.sampler.FixedBucketSampler(
data_train_len,
batch_size=batch_size,
num_buckets=10,
ratio=0,
shuffle=True)
# data loader for training
loader_train = gluon.data.DataLoader(
dataset=data_train,
num_workers=num_workers,
batch_sampler=batch_sampler,
batchify_fn=batchify_fn)
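# Hedged usage sketch, not from the original script: batches come out in the
# order of the Tuple batchify above, i.e. (token ids, valid length, segment ids, label).
for token_ids, valid_length, segment_ids, label in loader_train:
    pass  # forward/backward pass of the classifier goes here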