Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
lengths = [self.len_of_sent(i) for i in corpus]
if len(corpus) < 32:
n_buckets = 1
else:
n_buckets = min(self.config.n_buckets, len(corpus))
buckets = dict(zip(*kmeans(lengths, n_buckets)))
sizes, buckets = zip(*[
(size, bucket) for size, bucket in buckets.items()
])
# the number of chunks in each bucket, which is clipped by
# range [1, len(bucket)]
chunks = [min(len(bucket), max(round(size * len(bucket) / batch_size), 1)) for size, bucket in
zip(sizes, buckets)]
range_fn = randperm if shuffle else arange
max_samples_per_batch = self.config.get('max_samples_per_batch', None)
for i in tolist(range_fn(len(buckets))):
split_sizes = [(len(buckets[i]) - j - 1) // chunks[i] + 1
for j in range(chunks[i])] # how many sentences in each batch
for batch_indices in tf.split(range_fn(len(buckets[i])), split_sizes):
indices = [buckets[i][j] for j in tolist(batch_indices)]
if max_samples_per_batch:
for j in range(0, len(indices), max_samples_per_batch):
yield from self.batched_inputs_to_batches(corpus, indices[j:j + max_samples_per_batch],
shuffle)
else:
yield from self.batched_inputs_to_batches(corpus, indices, shuffle)
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
    """Decode batched arc/relation matrices into per-row lists of ``(index, relation)`` pairs.

    Args:
        Y: A triple unpacked as ``(arc_preds, rel_preds, mask)``.
        gold: Not read by this method.
        inputs: Not read by this method.
        X: Not read by this method.

    Returns:
        One list per sentence. Each sentence holds one list per row of its
        arc matrix (first row and first column dropped); that list contains a
        ``(column_index, relation_token)`` tuple for every truthy arc cell,
        where ``column_index`` is the cell's column in the unsliced matrix.
        A row without any truthy cell yields ``[(0, self.orphan_relation)]``.
    """
    arc_preds, rel_preds, mask = Y
    token_counts = tf.math.count_nonzero(mask, axis=-1)
    decoded = []
    # NOTE(review): the per-sentence count only takes part in the zip; rows are
    # not trimmed by it, so rows past the masked length (presumably padding)
    # also produce entries -- confirm that callers truncate if needed.
    for arc_matrix, rel_matrix, _count in zip(arc_preds, rel_preds, token_counts):
        arc_rows = tolist(arc_matrix[1:, 1:])
        rel_rows = tolist(rel_matrix[1:, 1:])
        graph = []
        for arc_row, rel_row in zip(arc_rows, rel_rows):
            # Counting from 1 restores the column index of the unsliced matrix.
            heads = [
                (column, self.rel_vocab.idx_to_token[rel_id])
                for column, (flag, rel_id) in enumerate(zip(arc_row, rel_row), 1)
                if flag
            ]
            if not heads:
                # orphan
                heads.append((0, self.orphan_relation))
            graph.append(heads)
        decoded.append(graph)
    return decoded
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
    """Decode batched arc and relation predictions into ``(arc, relation)`` pairs.

    Args:
        Y: A triple unpacked as ``(arc_preds, rel_preds, mask)``.
        gold: Not read by this method.
        inputs: Not read by this method.
        X: Not read by this method.

    Returns:
        One list per sentence, each containing ``(arc, relation_token)``
        tuples for positions ``1`` through ``length`` inclusive, where
        ``length`` is the number of non-zero mask entries of that sentence.
        Relation ids are mapped through ``self.rel_vocab.idx_to_token``.
    """
    arc_preds, rel_preds, mask = Y
    token_counts = tf.math.count_nonzero(mask, axis=-1)
    decoded = []
    for arc_row, rel_row, n_tokens in zip(arc_preds, rel_preds, token_counts):
        # Position 0 is skipped (presumably a root/BOS placeholder -- confirm);
        # everything past the masked length is dropped.
        span = slice(1, n_tokens + 1)
        arc_values = tolist(arc_row)[span]
        rel_ids = tolist(rel_row)[span]
        pairs = []
        for arc_value, rel_id in zip(arc_values, rel_ids):
            pairs.append((arc_value, self.rel_vocab.idx_to_token[rel_id]))
        decoded.append(pairs)
    return decoded
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
    """Turn batched arc/relation predictions into lists of ``(arc, relation)`` tuples.

    Args:
        Y: A triple unpacked as ``(arc_preds, rel_preds, mask)``.
        gold: Not read by this method.
        inputs: Not read by this method.
        X: Not read by this method.

    Returns:
        A list with one entry per sentence. Each entry lists
        ``(arc, relation_token)`` for positions ``1..length``, with ``length``
        being the count of non-zero mask values along the last axis and the
        relation token looked up in ``self.rel_vocab.idx_to_token``.
    """
    arc_preds, rel_preds, mask = Y
    lengths = tf.math.count_nonzero(mask, axis=-1)
    # Outer comprehension walks the batch; inner one pairs each kept arc with
    # its relation token. Index 0 of every row is skipped and the row is cut
    # at the masked length.
    return [
        [
            (arc, self.rel_vocab.idx_to_token[rel])
            for arc, rel in zip(tolist(arc_sent)[1:length + 1],
                                tolist(rel_sent)[1:length + 1])
        ]
        for arc_sent, rel_sent, length in zip(arc_preds, rel_preds, lengths)
    ]
def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
    """Convert batched form/cpos id tensors back into ``(form, cpos)`` token pairs.

    Args:
        X: Either ``(form_batch, cposes_batch)`` or
            ``(form_batch, cposes_batch, mask)``. When the mask is omitted it
            is derived as ``form_batch != 0``.

    Returns:
        One list per sentence containing ``(form_token, cpos_token)`` tuples
        for positions ``1`` through ``length`` inclusive, where ``length`` is
        the number of non-zero mask entries of that sentence.

    Raises:
        ValueError: If ``X`` has neither 2 nor 3 elements.
    """
    n_parts = len(X)
    if n_parts == 3:
        form_batch, cposes_batch, mask = X
    elif n_parts == 2:
        form_batch, cposes_batch = X
        # No explicit mask: every non-zero form id counts as a real token
        # (id 0 is presumably padding -- confirm against the vocabulary).
        mask = tf.not_equal(form_batch, 0)
    else:
        raise ValueError(f'Expect X to be 2 or 3 elements but got {repr(X)}')

    token_counts = tf.math.count_nonzero(mask, axis=-1)
    decoded = []
    for form_row, cpos_row, n_tokens in zip(form_batch, cposes_batch, token_counts):
        # Skip position 0 and drop everything past the masked length.
        span = slice(1, n_tokens + 1)
        form_ids = tolist(form_row)[span]
        cpos_ids = tolist(cpos_row)[span]
        decoded.append([
            (self.form_vocab.idx_to_token[form_id], self.cpos_vocab.idx_to_token[cpos_id])
            for form_id, cpos_id in zip(form_ids, cpos_ids)
        ])
    return decoded
def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
    """Map batched form and cpos ids back to their vocabulary tokens.

    Args:
        X: A 2-element ``(form_batch, cposes_batch)`` or 3-element
            ``(form_batch, cposes_batch, mask)`` input. Without a mask,
            ``form_batch != 0`` is used instead.

    Returns:
        A list with one entry per sentence; each entry lists
        ``(form_token, cpos_token)`` for positions ``1..length``, where
        ``length`` counts the non-zero mask values along the last axis.

    Raises:
        ValueError: If ``X`` does not have exactly 2 or 3 elements.
    """
    arity = len(X)
    if arity not in (2, 3):
        raise ValueError(f'Expect X to be 2 or 3 elements but got {repr(X)}')
    if arity == 2:
        form_batch, cposes_batch = X
        # Fall back to a mask marking the non-zero form ids.
        mask = tf.not_equal(form_batch, 0)
    else:
        form_batch, cposes_batch, mask = X

    lengths = tf.math.count_nonzero(mask, axis=-1)
    # Index 0 of every row is skipped and the row is cut at the masked length
    # before ids are translated through the two vocabularies.
    return [
        [
            (self.form_vocab.idx_to_token[f], self.cpos_vocab.idx_to_token[c])
            for f, c in zip(tolist(form_sent)[1:length + 1],
                            tolist(cposes_sent)[1:length + 1])
        ]
        for form_sent, cposes_sent, length in zip(form_batch, cposes_batch, lengths)
    ]