Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): this excerpt starts mid-function and its indentation has been
# stripped; the `else:` below pairs with an `if` branch above this chunk, and
# the convert_examples_to_features(...) call opened further down is closed
# past the end of this view. Code left byte-identical; comments only.
# Training/eval path (the `if` branch, not visible here): build examples
# from a DataFrame. `get_examples_from_df` is defined elsewhere — TODO confirm.
examples = get_examples_from_df(data)
else:
# Prediction path: `to_predict` is used directly as the example list
# (presumably raw inputs supplied by the caller — verify shape against
# get_examples_from_df output), and caching is forced off so ad-hoc
# predictions never read or write stale cached features.
examples = to_predict
no_cache = True
# Cache key encodes mode, model type, max sequence length, label count and
# example count, so changing any of those produces a different cache file.
# NOTE(review): `args` here vs `self.args` two lines below — presumably the
# same dict bound earlier in the method (not visible); confirm.
cached_features_file = os.path.join(args["cache_dir"], "cached_{}_{}_{}_{}_{}".format(mode, args["model_type"], args["max_seq_length"], self.num_labels, len(examples)))
# Ensure the cache directory exists. NOTE(review): os.mkdir does not create
# intermediate parent directories — os.makedirs(..., exist_ok=True) would be
# more robust if cache_dir can be nested; confirm expected config values.
if not os.path.isdir(self.args["cache_dir"]):
os.mkdir(self.args["cache_dir"])
# Reuse cached features only when the file exists AND the user did not ask
# to reprocess AND caching is not disabled (prediction path sets no_cache).
if os.path.exists(cached_features_file) and not args["reprocess_input_data"] and not no_cache:
# NOTE(review): torch.load deserializes via pickle — loading a cache file
# from an untrusted location can execute arbitrary code; acceptable here
# only because cache_dir is locally produced, but worth flagging.
features = torch.load(cached_features_file)
print(f"Features loaded from cache at {cached_features_file}")
else:
print(f"Converting to features started.")
# Convert raw examples into model-ready features, with per-architecture
# switches below (XLNet vs RoBERTa vs BERT-style defaults).
features = convert_examples_to_features(
examples,
self.labels,
self.args['max_seq_length'],
# NOTE(review): `self.tokenizer` here but bare `tokenizer` in the kwargs
# below — presumably aliased earlier in the method (not visible); confirm.
self.tokenizer,
# XLNet has a CLS token at the end
cls_token_at_end=bool(args["model_type"] in ["xlnet"]),
cls_token=tokenizer.cls_token,
# XLNet uses segment id 2 for its CLS token; BERT-style models use 0.
cls_token_segment_id=2 if args["model_type"] in ["xlnet"] else 0,
sep_token=tokenizer.sep_token,
# RoBERTa uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
sep_token_extra=bool(args["model_type"] in ["roberta"]),
# PAD on the left for XLNet
pad_on_left=bool(args["model_type"] in ["xlnet"]),
# Resolve the tokenizer's pad token string to its vocabulary id.
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
# XLNet uses segment id 4 for padding positions; others use 0.
pad_token_segment_id=4 if args["model_type"] in ["xlnet"] else 0,
# Label id assigned to padding/sub-word positions so the loss ignores them.
pad_token_label_id=self.pad_token_label_id,