Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
@raises(KrakenInvalidModelException)
def test_load_any_pyrnn_py3(self):
    """
    Test load_any doesn't load pickled models on python 3.

    The test passes only when the call below raises
    KrakenInvalidModelException, as required by the ``raises`` decorator.
    """
    # The call is expected to raise, so its result is deliberately not bound
    # to a name (the previous unused local could never be read).
    models.load_any(os.path.join(resources, 'model.pyrnn.gz'))
@raises(KrakenInvalidModelException)
def test_load_invalid(self):
    """
    Check that load_any handles an invalid file correctly.

    The ``raises`` decorator requires the load attempt to fail with
    KrakenInvalidModelException.
    """
    # Path of the fixture held on the test instance.
    invalid_path = self.temp.name
    models.load_any(invalid_path)
def load_pyrnn_model(cls, path: str):
"""
Loads an pyrnn model to VGSL.
"""
if not PY2:
raise KrakenInvalidModelException('Loading pickle models is not supported on python 3')
import cPickle
def find_global(mname, cname):
aliases = {
'lstm.lstm': kraken.lib.lstm,
'ocrolib.lstm': kraken.lib.lstm,
'ocrolib.lineest': kraken.lib.lineest,
}
if mname in aliases:
return getattr(aliases[mname], cname)
return getattr(sys.modules[mname], cname)
of = io.open
if path.endswith(u'.gz'):
of = gzip.open
def load_pronn_model(cls, path: str):
"""
Loads an pronn model to VGSL.
"""
with open(path, 'rb') as fp:
net = pyrnn_pb2.pyrnn()
try:
net.ParseFromString(fp.read())
except Exception:
raise KrakenInvalidModelException('File does not contain valid proto msg')
if not net.IsInitialized():
raise KrakenInvalidModelException('Model incomplete')
# extract codec
codec = PytorchCodec(net.codec)
input = net.ninput
hidden = net.fwdnet.wgi.dim[0]
# extract weights
weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')
fwd_w = []
rev_w = []
for w in weightnames:
fwd_ar = getattr(net.fwdnet, w)
def load_clstm_model(cls, path: str):
"""
Loads an CLSTM model to VGSL.
"""
net = clstm_pb2.NetworkProto()
with open(path, 'rb') as fp:
try:
net.ParseFromString(fp.read())
except Exception:
raise KrakenInvalidModelException('File does not contain valid proto msg')
if not net.IsInitialized():
raise KrakenInvalidModelException('Model incomplete')
input = net.ninput
attrib = {a.key: a.value for a in list(net.attribute)}
# mainline clstm model
if len(attrib) > 1:
mode = 'clstm'
else:
mode = 'clstm_compat'
# extract codec
codec = PytorchCodec([u''] + [chr(x) for x in net.codec[1:]])
# separate layers
nets = {}
nets['softm'] = [n for n in list(net.sub) if n.kind == 'SoftmaxLayer'][0]
parallel = [n for n in list(net.sub) if n.kind == 'Parallel'][0]
if mname in aliases:
return getattr(aliases[mname], cname)
return getattr(sys.modules[mname], cname)
of = io.open
if path.endswith(u'.gz'):
of = gzip.open
with io.BufferedReader(of(path, 'rb')) as fp:
unpickler = cPickle.Unpickler(fp)
unpickler.find_global = find_global
try:
net = unpickler.load()
except Exception as e:
raise KrakenInvalidModelException(str(e))
if not isinstance(net, kraken.lib.lstm.SeqRecognizer):
raise KrakenInvalidModelException('Pickle is %s instead of '
'SeqRecognizer' %
type(net).__name__)
# extract codec
codec = PytorchCodec({k: [v] for k, v in net.codec.char2code.items()})
input = net.Ni
parallel, softmax = net.lstm.nets
fwdnet, revnet = parallel.nets
revnet = revnet.net
hidden = fwdnet.WGI.shape[0]
# extract weights
weightnames = ('WGI', 'WGF', 'WCI', 'WGO', 'WIP', 'WFP', 'WOP')
fwd_w = []
def load_clstm_model(cls, path: str):
"""
Loads an CLSTM model to VGSL.
"""
net = clstm_pb2.NetworkProto()
with open(path, 'rb') as fp:
try:
net.ParseFromString(fp.read())
except Exception:
raise KrakenInvalidModelException('File does not contain valid proto msg')
if not net.IsInitialized():
raise KrakenInvalidModelException('Model incomplete')
input = net.ninput
attrib = {a.key: a.value for a in list(net.attribute)}
# mainline clstm model
if len(attrib) > 1:
mode = 'clstm'
else:
mode = 'clstm_compat'
# extract codec
codec = PytorchCodec([u''] + [chr(x) for x in net.codec[1:]])
# separate layers
nets = {}
'ocrolib.lineest': kraken.lib.lineest,
}
if mname in aliases:
return getattr(aliases[mname], cname)
return getattr(sys.modules[mname], cname)
of = io.open
if path.endswith(u'.gz'):
of = gzip.open
with io.BufferedReader(of(path, 'rb')) as fp:
unpickler = cPickle.Unpickler(fp)
unpickler.find_global = find_global
try:
net = unpickler.load()
except Exception as e:
raise KrakenInvalidModelException(str(e))
if not isinstance(net, kraken.lib.lstm.SeqRecognizer):
raise KrakenInvalidModelException('Pickle is %s instead of '
'SeqRecognizer' %
type(net).__name__)
# extract codec
codec = PytorchCodec({k: [v] for k, v in net.codec.char2code.items()})
input = net.Ni
parallel, softmax = net.lstm.nets
fwdnet, revnet = parallel.nets
revnet = revnet.net
hidden = fwdnet.WGI.shape[0]
# extract weights
weightnames = ('WGI', 'WGF', 'WCI', 'WGO', 'WIP', 'WFP', 'WOP')
nn = TorchVGSLModel.load_model(str(fname))
kind = 'vgsl'
except Exception:
try:
nn = TorchVGSLModel.load_clstm_model(fname)
kind = 'clstm'
except Exception:
nn = TorchVGSLModel.load_pronn_model(fname)
kind = 'pronn'
try:
nn = TorchVGSLModel.load_pyrnn_model(fname)
kind = 'pyrnn'
except Exception:
pass
if not nn:
raise KrakenInvalidModelException('File {} not loadable by any parser.'.format(fname))
seq = TorchSeqRecognizer(nn, train=train, device=device)
seq.kind = kind
return seq
def load_pronn_model(cls, path: str):
"""
Loads an pronn model to VGSL.
"""
with open(path, 'rb') as fp:
net = pyrnn_pb2.pyrnn()
try:
net.ParseFromString(fp.read())
except Exception:
raise KrakenInvalidModelException('File does not contain valid proto msg')
if not net.IsInitialized():
raise KrakenInvalidModelException('Model incomplete')
# extract codec
codec = PytorchCodec(net.codec)
input = net.ninput
hidden = net.fwdnet.wgi.dim[0]
# extract weights
weightnames = ('wgi', 'wgf', 'wci', 'wgo', 'wip', 'wfp', 'wop')
fwd_w = []
rev_w = []
for w in weightnames:
fwd_ar = getattr(net.fwdnet, w)
rev_ar = getattr(net.revnet, w)
fwd_w.append(torch.Tensor(fwd_ar.value).view(list(fwd_ar.dim)))