Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, sampling_rate, duration=3, direction='random',
             shift_max=3, shift_direction='both',
             name='Shift_Aug', verbose=0):
    """Initialise the augmenter and build its model.

    ``shift_direction`` and ``shift_max`` are deprecated aliases. When
    either one differs from its default (``'both'`` / ``3``), a deprecation
    notice is printed and its value replaces ``direction`` / ``duration``
    respectively before the model is built.
    """
    super().__init__(
        action=Action.SUBSTITUTE,
        name=name,
        device='cpu',
        verbose=verbose,
    )

    # Deprecated alias for ``direction``: a non-default value wins.
    if shift_direction != 'both':
        notice = WarningMessage.DEPRECATED.format(
            'shift_direction', '0.0.12', 'direction')
        print(notice)
        direction = shift_direction

    # Deprecated alias for ``duration``: a non-default value wins.
    if shift_max != 3:
        notice = WarningMessage.DEPRECATED.format(
            'shift_max', '0.0.12', 'duration')
        print(notice)
        duration = shift_max

    self.model = self.get_model(sampling_rate, duration, direction)
def _get_aug_idxes(self, tokens):
    """Choose which token positions to augment.

    The requested count comes from ``generate_aug_cnt``; candidates come
    from ``pre_skip_aug`` (called with ``tuple_idx=0``) filtered through
    ``skip_aug``. Returns whatever ``self.sample`` yields for the capped
    count, or ``None`` when no candidate is left.
    """
    target_cnt = self.generate_aug_cnt(len(tokens))
    candidates = self.skip_aug(self.pre_skip_aug(tokens, tuple_idx=0), tokens)

    if len(candidates) == 0:
        # Nothing eligible: optionally report it, then signal "no indexes".
        if self.verbose > 0:
            WarningException(
                name=WarningName.OUT_OF_VOCABULARY,
                code=WarningCode.WARNING_CODE_002,
                msg=WarningMessage.NO_WORD,
            ).output()
        return None

    # Never ask the sampler for more positions than there are candidates.
    return self.sample(candidates, min(target_cnt, len(candidates)))
def _validate_augment(cls, data):
    """Check that ``data`` is usable as augmentation input.

    Returns an empty list when ``data`` is not ``None`` and has non-zero
    length; otherwise returns a one-element list holding a
    ``WarningException`` describing the zero-length input.
    """
    has_content = data is not None and len(data) != 0
    if has_content:
        return []

    return [
        WarningException(
            name=WarningName.INPUT_VALIDATION_WARNING,
            code=WarningCode.WARNING_CODE_001,
            msg=WarningMessage.LENGTH_IS_ZERO,
        )
    ]
# Pick token positions to augment, weighting candidates by the values that
# self.model.cal_tfidf returns for them.
# NOTE(review): this excerpt ends inside the retry loop -- no return statement
# is visible and aug_cnt / aug_idxes are not consumed in the lines shown, so
# the function presumably continues past this point; confirm against the full
# source before relying on it.
def _get_aug_idxes(self, tokens):
# Requested number of positions, derived from the token count.
aug_cnt = self.generate_aug_cnt(len(tokens))
# Candidate positions: pre_skip_aug output filtered through skip_aug.
word_idxes = self.pre_skip_aug(tokens)
word_idxes = self.skip_aug(word_idxes, tokens)
# No candidates left: optionally emit a warning, then return None.
if len(word_idxes) == 0:
if self.verbose > 0:
exception = WarningException(name=WarningName.OUT_OF_VOCABULARY,
code=WarningCode.WARNING_CODE_002, msg=WarningMessage.NO_WORD)
exception.output()
return None
# Cap the requested count at the number of candidates.
if len(word_idxes) < aug_cnt:
aug_cnt = len(word_idxes)
# presumably one probability per entry of word_idxes -- TODO confirm
# against cal_tfidf.
aug_probs = self.model.cal_tfidf(word_idxes, tokens)
aug_idxes = []
# It is possible that no token is picked. So re-try
retry_cnt = 3
# Work on a copy so word_idxes itself is not modified by the removals below.
possible_idxes = word_idxes.copy()
for _ in range(retry_cnt):
# Each candidate is selected when a fresh self.prob() draw is below its
# probability p.
# NOTE(review): possible_idxes is being iterated (via zip) while remove() is
# called on it. Removing from a list during iteration makes the iterator skip
# the following element, and after any removal the remaining candidates no
# longer line up positionally with aug_probs -- verify this is intended.
for i, p in zip(possible_idxes, aug_probs):
if self.prob() < p:
aug_idxes.append(i)
possible_idxes.remove(i)
def _get_aug_idxes(self, tokens):
    """Choose which stopword positions to augment.

    A token is a candidate when it (lower-cased unless
    ``self.case_sensitive`` is set) is found in ``self.stopwords``. The
    candidates are filtered through ``skip_aug`` and then sampled down to
    the count from ``generate_aug_cnt``. Returns ``None`` when no candidate
    is left.
    """
    target_cnt = self.generate_aug_cnt(len(tokens))

    stop_positions = [
        pos for pos, tok in enumerate(tokens)
        if (tok if self.case_sensitive else tok.lower()) in self.stopwords
    ]
    stop_positions = self.skip_aug(stop_positions, tokens)

    if len(stop_positions) == 0:
        # Nothing eligible: optionally report it, then signal "no indexes".
        if self.verbose > 0:
            Warning(
                name=WarningName.OUT_OF_VOCABULARY,
                code=WarningCode.WARNING_CODE_002,
                msg=WarningMessage.NO_WORD,
            ).output()
        return None

    # Never ask the sampler for more positions than there are candidates.
    return self.sample(stop_positions, min(target_cnt, len(stop_positions)))
def __init__(self, zone=(0.2, 0.8), coverage=1.,
             factor=(0.5, 2), loudness_factor=(0.5, 2), name='Loudness_Aug', verbose=0):
    """Initialise the augmenter and build its model.

    ``loudness_factor`` is a deprecated alias for ``factor``. When it
    differs from its default ``(0.5, 2)``, a deprecation notice is printed
    and its value replaces ``factor`` before the model is built from
    ``zone``, ``coverage`` and ``factor``.
    """
    super().__init__(
        action=Action.SUBSTITUTE,
        name=name,
        device='cpu',
        verbose=verbose,
    )

    # Deprecated alias for ``factor``: a non-default value wins.
    if loudness_factor != (0.5, 2):
        notice = WarningMessage.DEPRECATED.format(
            'loudness_factor', '0.0.12', 'factor')
        print(notice)
        factor = loudness_factor

    self.model = self.get_model(zone, coverage, factor)