def test_oov(self):
    # Out-of-vocabulary token: a string that cannot appear in BERT's vocabulary.
    unknown_token = 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
    texts = [
        unknown_token,
        unknown_token + ' the'
    ]
    augmenters = [
        naw.BertAug(action=Action.INSERT),
        naw.BertAug(action=Action.SUBSTITUTE)
    ]
    for aug in augmenters:
        for text in texts:
            self.assertLess(0, len(text))
            augmented_text = aug.augment(text)
            if aug.action == Action.INSERT:
                # Insertion must grow the token count.
                self.assertLess(len(text.split(' ')), len(augmented_text.split(' ')))
            elif aug.action == Action.SUBSTITUTE:
                # Substitution must preserve the token count.
                self.assertEqual(len(text.split(' ')), len(augmented_text.split(' ')))
            else:
                raise Exception('Augmenter is neither INSERT nor SUBSTITUTE')
            self.assertNotEqual(text, augmented_text)
            # BERT's subword prefix (e.g. '##') must not leak into the output.
            self.assertTrue(nml.Bert.SUBWORD_PREFIX not in augmented_text)
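Outside the test harness, the same insert/substitute behavior can be exercised directly. A minimal sketch, assuming the imports used throughout these tests (`nlpaug.augmenter.word as naw`, `Action` from `nlpaug.util`):

import nlpaug.augmenter.word as naw
from nlpaug.util import Action

text = 'The quick brown fox jumps over the lazy dog'
for action in (Action.INSERT, Action.SUBSTITUTE):
    aug = naw.BertAug(action=action)  # same constructor as the test above
    print(action, '->', aug.augment(text))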
def test_substitute_stopwords(self):
    texts = [
        'The quick brown fox jumps over the lazy dog'
    ]
    # Treat the first three words as stopwords that must not be augmented.
    stopwords = [t.lower() for t in texts[0].split(' ')[:3]]
    aug_n = 3
    aug = naw.SpellingAug(dict_path=os.environ.get("MODEL_DIR") + 'spelling_en.txt', stopwords=stopwords)
    for text in texts:
        self.assertLess(0, len(text))
        augmented_text = aug.augment(text)
        augmented_tokens = aug.tokenizer(augmented_text)
        tokens = aug.tokenizer(text)
        augmented_cnt = 0
        for token, augmented_token in zip(tokens, augmented_tokens):
            if token.lower() in stopwords and len(token) > aug_n:
                # Stopwords must pass through unchanged.
                self.assertEqual(token.lower(), augmented_token)
            else:
                augmented_cnt += 1
        # At least one non-stopword token should have been eligible for augmentation.
        self.assertGreater(augmented_cnt, 0)
def test_antonyms(self):
    texts = [
        'Good bad'
    ]
    # is_synonym=False makes WordNetAug substitute antonyms instead of synonyms.
    aug = naw.WordNetAug(is_synonym=False)
    for text in texts:
        self.assertLess(0, len(text))
        augmented_text = aug.augment(text)
        self.assertNotEqual(text, augmented_text)
    self.assertLess(0, len(texts))
@classmethod
def setUpClass(cls):
    # Build paired lists of word-embedding augmenters, one per action.
    cls.insert_augmenters = [
        naw.Word2vecAug(
            model_path=os.environ.get("MODEL_DIR") + 'GoogleNews-vectors-negative300.bin',
            action=Action.INSERT),
        naw.FasttextAug(
            model_path=os.environ.get("MODEL_DIR") + 'wiki-news-300d-1M.vec',
            action=Action.INSERT),
        naw.GloVeAug(
            model_path=os.environ.get("MODEL_DIR") + 'glove.6B.50d.txt',
            action=Action.INSERT)
    ]
    cls.substitute_augmenters = [
        naw.Word2vecAug(
            model_path=os.environ.get("MODEL_DIR") + 'GoogleNews-vectors-negative300.bin',
            action=Action.SUBSTITUTE),
        naw.FasttextAug(
            model_path=os.environ.get("MODEL_DIR") + 'wiki-news-300d-1M.vec',
            action=Action.SUBSTITUTE),
        naw.GloVeAug(
            model_path=os.environ.get("MODEL_DIR") + 'glove.6B.50d.txt',
            action=Action.SUBSTITUTE)
    ]
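These class-level lists are presumably consumed by insert/substitute tests in the same class, mirroring the standalone tests elsewhere on this page. A hedged sketch of such a consumer (the method name and assertions here are illustrative, not from the original):

def test_insert(self):
    text = 'The quick brown fox jumps over the lazy dog'
    for aug in self.insert_augmenters:
        augmented_text = aug.augment(text)
        # Insertion should grow the token count and change the text.
        self.assertLess(len(text.split(' ')), len(augmented_text.split(' ')))
        self.assertNotEqual(text, augmented_text)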
def test_substitute(self):
    texts = [
        'The quick brown fox jumps over the lazy dog'
    ]
    aug = naw.FasttextAug(
        model_path=os.environ.get("MODEL_DIR") + 'wiki-news-300d-1M.vec',
        action=Action.SUBSTITUTE)
    for text in texts:
        self.assertLess(0, len(text))
        augmented_text = aug.augment(text)
        self.assertNotEqual(text, augmented_text)
    self.assertLess(0, len(texts))
def test_multiprocess_gpu(self):
    text = 'The quick brown fox jumps over the lazy dog'
    n = 3
    aug = naw.ContextualWordEmbsAug(force_reload=True, device='cuda')
    # Generate n augmented variants in parallel across n threads on the GPU.
    augmented_texts = aug.augment(text, n=n, num_thread=n)
    self.assertGreater(len(augmented_texts), 1)
    for augmented_text in augmented_texts:
        self.assertNotEqual(augmented_text, text)
def test_substitute(self):
    texts = [
        'The quick brown fox jumps over the lazy dog'
    ]
    # WordNetAug defaults to synonym substitution.
    aug = naw.WordNetAug()
    for text in texts:
        self.assertLess(0, len(text))
        augmented_text = aug.augment(text)
        self.assertNotEqual(text, augmented_text)
    self.assertLess(0, len(texts))
def test_oov(self):
    # A token absent from the spelling dictionary should be returned unchanged.
    text = 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
    aug = naw.SpellingAug(dict_path=os.environ.get("MODEL_DIR") + 'spelling_en.txt')
    augmented_text = aug.augment(text)
    self.assertEqual(text, augmented_text)
def test_substitute(self):
    texts = [
        'The quick brown fox jumps over the lazy dog'
    ]
    aug = naw.GloVeAug(
        model_path=os.environ.get("MODEL_DIR") + 'glove.6B.50d.txt',
        action=Action.SUBSTITUTE)
    for text in texts:
        self.assertLess(0, len(text))
        augmented_text = aug.augment(text)
        self.assertNotEqual(text, augmented_text)
    self.assertLess(0, len(texts))
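For reference, every augmenter exercised above follows the same construct-then-augment pattern outside unittest as well. A minimal sketch, assuming the MODEL_DIR environment variable points at a directory containing the model files referenced in these tests (and, as in the tests, ends with a path separator):

import os
import nlpaug.augmenter.word as naw
from nlpaug.util import Action

model_dir = os.environ.get("MODEL_DIR")  # assumption: set and separator-terminated
text = 'The quick brown fox jumps over the lazy dog'

aug = naw.GloVeAug(
    model_path=model_dir + 'glove.6B.50d.txt',
    action=Action.SUBSTITUTE)
print(aug.augment(text))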