Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def setUp(self):
    """Construct the codec fixtures shared by the test methods."""
    # one code point per label
    self.o2o_codec = codec.PytorchCodec('ab')
    # several code points collapsing onto a single label each
    self.m2o_codec = codec.PytorchCodec(['aaa', 'aa', 'a', 'b'])
    # a single code point expanding into several labels
    self.o2m_codec = codec.PytorchCodec({'a': [10, 11, 12], 'b': [12, 45, 80]})
    # many code points mapping to many labels
    many_to_many = {
        'aaa': [10, 11, 12],
        'aa': [10, 10],
        'a': [10],
        'bb': [15],
        'b': [12],
    }
    self.m2m_codec = codec.PytorchCodec(many_to_many)
    # character sequences for encode(); the "invalid" one contains 'c',
    # which none of the codecs above know about
    self.invalid_c_sequence = 'aaababbcaaa'
    self.valid_c_sequence = 'aaababbaaabbbb'
    # 4-tuples for decode() (presumably label, start, end, confidence —
    # confirm against the decode() signature)
    self.invalid_l_sequence = [
        (45, 78, 778, 0.3793492615638364),
        (10, 203, 859, 0.9485075253700872),
        (11, 70, 601, 0.7885297329523855),
        (12, 251, 831, 0.7216817042926938),
        (900, 72, 950, 0.27609823017048707),
    ]
def setUp(self):
    """Construct the codec fixtures shared by the test methods."""
    # one code point per label
    self.o2o_codec = codec.PytorchCodec('ab')
    # several code points collapsing onto a single label each
    self.m2o_codec = codec.PytorchCodec(['aaa', 'aa', 'a', 'b'])
    # a single code point expanding into several labels
    self.o2m_codec = codec.PytorchCodec({'a': [10, 11, 12], 'b': [12, 45, 80]})
    # many code points mapping to many labels
    many_to_many = {
        'aaa': [10, 11, 12],
        'aa': [10, 10],
        'a': [10],
        'bb': [15],
        'b': [12],
    }
    self.m2m_codec = codec.PytorchCodec(many_to_many)
    # character sequences for encode(); the "invalid" one contains 'c',
    # which none of the codecs above know about
    self.invalid_c_sequence = 'aaababbcaaa'
    self.valid_c_sequence = 'aaababbaaabbbb'
    # 4-tuples for decode() (presumably label, start, end, confidence —
    # confirm against the decode() signature)
    self.invalid_l_sequence = [
        (45, 78, 778, 0.3793492615638364),
        (10, 203, 859, 0.9485075253700872),
        (11, 70, 601, 0.7885297329523855),
        (12, 251, 831, 0.7216817042926938),
        (900, 72, 950, 0.27609823017048707),
    ]
def setUp(self):
    """Construct the codec fixtures shared by the test methods."""
    # one code point per label
    self.o2o_codec = codec.PytorchCodec('ab')
    # several code points collapsing onto a single label each
    self.m2o_codec = codec.PytorchCodec(['aaa', 'aa', 'a', 'b'])
    # a single code point expanding into several labels
    self.o2m_codec = codec.PytorchCodec({'a': [10, 11, 12], 'b': [12, 45, 80]})
    # many code points mapping to many labels
    many_to_many = {
        'aaa': [10, 11, 12],
        'aa': [10, 10],
        'a': [10],
        'bb': [15],
        'b': [12],
    }
    self.m2m_codec = codec.PytorchCodec(many_to_many)
    # character sequences for encode(); the "invalid" one contains 'c',
    # which none of the codecs above know about
    self.invalid_c_sequence = 'aaababbcaaa'
    self.valid_c_sequence = 'aaababbaaabbbb'
    # 4-tuples for decode() (presumably label, start, end, confidence —
    # confirm against the decode() signature)
    self.invalid_l_sequence = [
        (45, 78, 778, 0.3793492615638364),
        (10, 203, 859, 0.9485075253700872),
        (11, 70, 601, 0.7885297329523855),
        (12, 251, 831, 0.7216817042926938),
        (900, 72, 950, 0.27609823017048707),
    ]
def setUp(self):
    """Construct the codec fixtures shared by the test methods."""
    # one code point per label
    self.o2o_codec = codec.PytorchCodec('ab')
    # several code points collapsing onto a single label each
    self.m2o_codec = codec.PytorchCodec(['aaa', 'aa', 'a', 'b'])
    # a single code point expanding into several labels
    self.o2m_codec = codec.PytorchCodec({'a': [10, 11, 12], 'b': [12, 45, 80]})
    # many code points mapping to many labels
    many_to_many = {
        'aaa': [10, 11, 12],
        'aa': [10, 10],
        'a': [10],
        'bb': [15],
        'b': [12],
    }
    self.m2m_codec = codec.PytorchCodec(many_to_many)
    # character sequences for encode(); the "invalid" one contains 'c',
    # which none of the codecs above know about
    self.invalid_c_sequence = 'aaababbcaaa'
    self.valid_c_sequence = 'aaababbaaabbbb'
    # 4-tuples for decode() (presumably label, start, end, confidence —
    # confirm against the decode() signature)
    self.invalid_l_sequence = [
        (45, 78, 778, 0.3793492615638364),
        (10, 203, 859, 0.9485075253700872),
        (11, 70, 601, 0.7885297329523855),
        (12, 251, 831, 0.7216817042926938),
        (900, 72, 950, 0.27609823017048707),
    ]
# NOTE(review): fragment of a larger CLI training routine — `gt_set`, `nn`,
# `resize`, `message`, `logger`, `ctx` and the incoming `codec` are defined
# in the enclosing scope (not visible here), and the original indentation of
# the try/except and if/elif bodies has been lost in this extract.
# prefer explicitly given codec over network codec if mode is 'both'
codec = codec if (codec and resize == 'both') else nn.codec
try:
# raises KrakenEncodeException when the training set contains code points
# the codec cannot encode; the handler below reconciles the two
gt_set.encode(codec)
except KrakenEncodeException as e:
# NOTE(review): bound exception `e` is never used
message('Network codec not compatible with training set')
# code points present in the training data but missing from the codec
alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
if resize == 'fail':
logger.error('Training data and model codec alphabets mismatch: {}'.format(alpha_diff))
ctx.exit(code=1)
elif resize == 'add':
message('Adding missing labels to network ', nl=False)
logger.info('Resizing codec to include {} new code points'.format(len(alpha_diff)))
# append the missing code points after the current maximum label
codec.c2l.update({k: [v] for v, k in enumerate(alpha_diff, start=codec.max_label()+1)})
nn.add_codec(PytorchCodec(codec.c2l))
logger.info('Resizing last layer in network to {} outputs'.format(codec.max_label()+1))
nn.resize_output(codec.max_label()+1)
gt_set.encode(nn.codec)
message('\u2713', fg='green')
elif resize == 'both':
message('Fitting network exactly to training set ', nl=False)
logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
# re-encode with a codec derived from the data itself, then merge it
# with the existing one; merge() reports the labels to delete
gt_set.encode(None)
ncodec, del_labels = codec.merge(gt_set.codec)
logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
gt_set.encode(ncodec)
nn.resize_output(ncodec.max_label()+1, del_labels)
message('\u2713', fg='green')
else:
# any other resize value is a user error
raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
# NOTE(review): fragment of a protobuf-based (clstm) model loader — the
# opening `try:`/`with` and the function header are outside this extract,
# which is why it starts with an orphaned `else:` and an `except` without
# its `try`. Indentation has been flattened in this extract.
else:
net.ParseFromString(fp.read())
except Exception:
raise KrakenInvalidModelException('File does not contain valid proto msg')
if not net.IsInitialized():
raise KrakenInvalidModelException('Model incomplete')
# NOTE(review): `input` shadows the builtin of the same name
input = net.ninput
attrib = {a.key: a.value for a in list(net.attribute)}
# mainline clstm model
if len(attrib) > 1:
mode = 'clstm'
else:
mode = 'clstm_compat'
# extract codec
# first slot is reserved (empty string); the remaining entries are decoded
# from their integer code points
codec = PytorchCodec([u''] + [chr(x) for x in net.codec[1:]])
# separate layers
# pick out the softmax layer, the parallel container, and the forward and
# reversed LSTM sub-nets it wraps
nets = {}
nets['softm'] = [n for n in list(net.sub) if n.kind == 'SoftmaxLayer'][0]
parallel = [n for n in list(net.sub) if n.kind == 'Parallel'][0]
nets['lstm1'] = [n for n in list(parallel.sub) if n.kind.startswith('NPLSTM')][0]
rev = [n for n in list(parallel.sub) if n.kind == 'Reversed'][0]
nets['lstm2'] = rev.sub[0]
hidden = int(nets['lstm1'].attribute[0].value)
weights = {} # type: Dict[str, torch.Tensor]
# reshape each flat protobuf weight vector into its declared dimensions
for n in nets:
weights[n] = {}
for w in list(nets[n].weights):
weights[n][w.name] = torch.Tensor(w.value).view(list(w.dim))
# NOTE(review): truncated in this extract — the function body continues past
# the last line shown, and indentation has been flattened.
def load_pronn_model(cls, path: str):
"""
Load a pronn (protobuf pyrnn) model into VGSL.
Parses the protobuf message at *path* and extracts the codec, layer
sizes, and forward/reverse LSTM weight tensors.
Raises:
    KrakenInvalidModelException: if the file is not a valid or complete
        proto message.
"""
with open(path, 'rb') as fp:
net = pyrnn_pb2.pyrnn()
try:
net.ParseFromString(fp.read())
except Exception:
raise KrakenInvalidModelException('File does not contain valid proto msg')
if not net.IsInitialized():
raise KrakenInvalidModelException('Model incomplete')
# extract codec
codec = PytorchCodec(net.codec)
# NOTE(review): `input` shadows the builtin of the same name
input = net.ninput
hidden = net.fwdnet.wgi.dim[0]
# extract weights
# gate/peephole weight fields shared by the forward and reverse nets
weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')
fwd_w = []
rev_w = []
for w in weightnames:
fwd_ar = getattr(net.fwdnet, w)
rev_ar = getattr(net.revnet, w)
# flat protobuf value arrays reshaped to their declared dimensions
fwd_w.append(torch.Tensor(fwd_ar.value).view(list(fwd_ar.dim)))
rev_w.append(torch.Tensor(rev_ar.value).view(list(rev_ar.dim)))
# concatenate the four gate weight matrices (wgi, wgf, wci, wgo)
t = torch.cat(fwd_w[:4])
# NOTE(review): fragment of a codec merge routine — `v` on the first loop is
# a leftover variable from an enclosing (not shown) loop, and the original
# indentation has been flattened in this extract.
# NOTE(review): BUG — removing items from `rm_labels` while iterating over it
# skips the element that follows each removal; build a filtered copy instead.
# Flagged rather than changed because the enclosing context is not visible.
for l in rm_labels:
if l in v:
rm_labels.remove(l)
# iteratively remove labels, decrementing subsequent labels to close
# (new) holes in the codec.
offset_rm_labels = [v-idx for idx, v in enumerate(sorted(set(rm_labels)))]
for rlabel in offset_rm_labels:
c2l_cand = {k: [l-1 if l > rlabel else l for l in v] for k, v in c2l_cand.items()}
# add mappings not in original codec
add_list = {cseq: enc for cseq, enc in codec.c2l.items() if cseq not in self.c2l}
# renumber
# first free label after the surviving mappings; new mappings are assigned
# fresh consecutive labels from there
start_idx = max(label for v in c2l_cand.values() for label in v) + 1
add_labels = {k: v for v, k in enumerate(sorted(set(label for v in add_list.values() for label in v)), start_idx)}
for k, v in add_list.items():
c2l_cand[k] = [add_labels[label] for label in v]
# merged codec plus the set of labels dropped from the original
return PytorchCodec(c2l_cand), set(rm_labels)
# NOTE(review): fragment of a legacy pyrnn (pickle) model loader — the
# function header is outside this extract and indentation has been
# flattened.
# transparently support gzip-compressed pickles
of = io.open
if path.endswith(u'.gz'):
of = gzip.open
with io.BufferedReader(of(path, 'rb')) as fp:
unpickler = cPickle.Unpickler(fp)
# NOTE(review): `find_global` is the Python 2 pickle hook; on Python 3 the
# equivalent is overriding `Unpickler.find_class` — confirm which
# interpreter/compat shim this runs under.
unpickler.find_global = find_global
try:
# SECURITY: unpickling executes arbitrary code — only load model files
# from trusted sources.
net = unpickler.load()
except Exception as e:
raise KrakenInvalidModelException(str(e))
if not isinstance(net, kraken.lib.lstm.SeqRecognizer):
raise KrakenInvalidModelException('Pickle is %s instead of '
'SeqRecognizer' %
type(net).__name__)
# extract codec
# legacy char2code maps each character to a single integer label
codec = PytorchCodec({k: [v] for k, v in net.codec.char2code.items()})
# NOTE(review): `input` shadows the builtin of the same name
input = net.Ni
# the network is a (parallel, softmax) stack; the parallel layer holds the
# forward net and a Reversed wrapper around the backward net
parallel, softmax = net.lstm.nets
fwdnet, revnet = parallel.nets
revnet = revnet.net
hidden = fwdnet.WGI.shape[0]
# extract weights
weightnames = ('WGI', 'WGF', 'WCI', 'WGO', 'WIP', 'WFP', 'WOP')
fwd_w = []
rev_w = []
for w in weightnames:
fwd_w.append(torch.Tensor(getattr(fwdnet, w)))
rev_w.append(torch.Tensor(getattr(revnet, w)))