# NOTE(review): non-code scraper banner, commented out so it cannot break the
# file: "Secure your code as it's written. Use Snyk Code to scan source code in
# minutes - no build needed - and fix issues immediately."
from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
from kraken.lib.codec import PytorchCodec
from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms
# NOTE(review): fragment of a CLI training-setup routine; the enclosing
# function/command header is outside this view and the original leading
# indentation was lost in extraction, so nesting must be read from the syntax.
logger.info('Building ground truth set from {} line images'.format(len(ground_truth) + len(training_files)))
# Number of epochs already trained (0 for a fresh model).
completed_epochs = 0
# load model if given. if a new model has to be created we need to do that
# after data set initialization, otherwise the output size is still unknown.
nn = None
#hyper_fields = ['freq', 'quit', 'epochs', 'lag', 'min_delta', 'optimizer', 'lrate', 'momentum', 'weight_decay', 'schedule', 'partition', 'normalization', 'normalize_whitespace', 'reorder', 'preload', 'completed_epochs', 'output']
# Resume from an existing serialized VGSL model when a path was supplied.
if load:
logger.info('Loading existing model from {} '.format(load))
message('Loading existing model from {}'.format(load), nl=False)
nn = vgsl.TorchVGSLModel.load_model(load)
# Hyper-parameter restoration below is disabled; note that assigning through
# locals() would not rebind function-local names anyway — presumably why it
# was commented out (TODO confirm against upstream history).
#if nn.user_metadata and load_hyper_parameters:
# for param in hyper_fields:
# if param in nn.user_metadata:
# logger.info('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
# message('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
# locals()[param] = nn.user_metadata[param]
message('\u2713', fg='green', nl=False)  # green check mark: load succeeded
# preparse input sizes from vgsl string to seed ground truth data set
# sizes and dimension ordering.
if not nn:
spec = spec.strip()
# A VGSL spec must be wrapped in brackets; reject anything else early.
if spec[0] != '[' or spec[-1] != ']':
raise click.BadOptionUsage('spec', 'VGSL spec {} not bracketed'.format(spec))
blocks = spec[1:-1].split(' ')
# First block carries the four comma-separated input dimensions — presumably
# batch,height,width,channels; confirm against the VGSL spec documentation.
m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
{'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]], 'boundary': [[x0, y0], [x1, y1], ..., [x_m, y_m]]},
{'baseline': [[x0, ...]], 'boundary': [[x0, ...]]}
]
}: A dictionary containing the text direction and under the key 'lines'
a list of reading order sorted baselines (polylines) and their
respective polygonal boundaries. The last and first point of each
boundary polygon is connected.
Raises:
KrakenInputException if the input image is not binarized or the text
direction is invalid.
"""
# NOTE(review): body fragment of a page-segmentation entry point; its signature
# and the start of its docstring are outside this view, and the original
# indentation was stripped, so nesting must be read from the syntax below.
im_str = get_im_str(im)
logger.info('Segmenting {}'.format(im_str))
# 'model' arrives as a path/identifier and is rebound to the loaded network.
model = vgsl.TorchVGSLModel.load_model(model)
model.eval()  # inference mode (disables training-only behavior such as dropout)
# Optional mask restricting segmentation to a region of the page.
if mask:
# The mask must be strictly black/white (mode '1' or bitonal content).
if mask.mode != '1' and not is_bitonal(mask):
logger.error('Mask is not bitonal')
raise KrakenInputException('Mask is not bitonal')
mask = mask.convert('1')
# Mask and page image must have identical pixel dimensions.
if mask.size != im.size:
logger.error('Mask size {} doesn\'t match image size {}'.format(mask.size, im.size))
raise KrakenInputException('Mask size {} doesn\'t match image size {}'.format(mask.size, im.size))
logger.info('Masking enabled in segmenter.')
mask = pil2array(mask)
# Derive the input pipeline from the network's declared input geometry.
batch, channels, height, width = model.input
transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
# First two transforms presumably perform only the resize/rescale step —
# TODO confirm the transform ordering inside generate_input_transforms.
res_tf = tf.Compose(transforms.transforms[:2])
scal_im = res_tf(im).convert('L')  # rescaled copy forced to 8-bit grayscale
#! /usr/bin/env python
"""
Produces semi-transparent neural segmenter output overlays.

NOTE(review): the script appears truncated in this view — 'heat' is computed
but never used below, and 'splitext' is imported but unused here; the
compositing/saving step presumably follows in the full file. Original
indentation was lost in extraction, so nesting must be read from the syntax.
"""
import sys
import torch
from PIL import Image
from kraken.lib import vgsl, dataset
import torch.nn.functional as F
from os.path import splitext
# argv[1]: path to a VGSL segmentation model; argv[2:]: input image paths.
model = vgsl.TorchVGSLModel.load_model(sys.argv[1])
model.eval()
# Build the input transform pipeline from the model's declared input geometry.
batch, channels, height, width = model.input
transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
imgs = sys.argv[2:]
torch.set_num_threads(1)  # presumably to keep CPU load predictable — confirm
for img in imgs:
print(img)
im = Image.open(img)
with torch.no_grad():
o = model.nn(transforms(im).unsqueeze(0))
# PIL's size is (w, h) while interpolate expects (h, w) — hence the reversal.
o = F.interpolate(o, size=im.size[::-1])
o = o.squeeze().numpy()
# Channel 1 of the network output rendered as an 8-bit heat map image.
heat = Image.fromarray((o[1]*255).astype('uint8'))
* clstm for protobuf models generated by clstm
Args:
fname (str): Path to the model
train (bool): Enables gradient calculation and dropout layers in model.
device (str): Target device
Returns:
A kraken.lib.models.TorchSeqRecognizer object.
"""
# NOTE(review): fragment of a model-loading helper; its 'def' line and the
# start of its docstring are outside this view, and the original indentation
# was stripped, so the try/except nesting below must be read from the syntax.
nn = None
kind = ''  # records which serialization format finally parsed successfully
# Normalize the path: expand ~, environment variables, relative components.
fname = abspath(expandvars(expanduser(fname)))
logger.info(u'Loading model from {}'.format(fname))
# Try each serialization format in turn: vgsl -> clstm -> pronn -> pyrnn.
try:
nn = TorchVGSLModel.load_model(str(fname))
kind = 'vgsl'
except Exception:
try:
nn = TorchVGSLModel.load_clstm_model(fname)
kind = 'clstm'
except Exception:
nn = TorchVGSLModel.load_pronn_model(fname)
kind = 'pronn'
# NOTE(review): no try/except is visible around the pronn load above, so a
# pronn parse failure would propagate instead of falling through to the
# pyrnn attempt below — confirm the intended nesting against the upstream
# kraken.lib.models source.
try:
nn = TorchVGSLModel.load_pyrnn_model(fname)
kind = 'pyrnn'
except Exception:
pass
# No parser produced a model: report the file as unloadable.
if not nn:
raise KrakenInvalidModelException('File {} not loadable by any parser.'.format(fname))
# Wrap the raw network in a sequence recognizer; the fragment ends before any
# 'return', so the remainder of the function lies outside this view.
seq = TorchSeqRecognizer(nn, train=train, device=device)