Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def test_load_any_proto(self):
    """
    Check that ``models.load_any`` returns a ``TorchSeqRecognizer`` when
    given the protobuf model fixture ``model.pronn`` from ``resources``.
    """
    proto_path = os.path.join(resources, 'model.pronn')
    loaded = models.load_any(proto_path)
    self.assertIsInstance(loaded, kraken.lib.models.TorchSeqRecognizer)
def test_load_invalid(self):
    """
    Feed the temporary file ``self.temp`` to ``models.load_any`` to exercise
    the handling of invalid model files.

    NOTE(review): no assertion is visible in this method; presumably a
    decorator or the harness expects ``load_any`` to raise here (e.g. via
    ``assertRaises``) — confirm, otherwise this test checks nothing.
    """
    invalid_path = self.temp.name
    models.load_any(invalid_path)
def test_load_any_proto(self):
    """
    Verify that the protobuf-serialized fixture ``model.pronn`` is loaded by
    ``models.load_any`` as a ``TorchSeqRecognizer``.

    NOTE(review): a method named ``test_load_any_proto`` with the same body
    also appears earlier in this code; if both live in one class, this later
    definition silently replaces the earlier one — confirm and dedupe.
    """
    fixture = os.path.join(resources, 'model.pronn')
    recognizer = models.load_any(fixture)
    self.assertIsInstance(recognizer, kraken.lib.models.TorchSeqRecognizer)
# Resolve each script -> model entry of `model` to a file on disk, load it,
# and optionally install a 'default' fallback model.
# NOTE(review): fragment of a larger function whose `def` is not visible;
# `model`, `ctx`, `threads`, `message`, `APP_NAME` and `LEGACY_MODEL_DIR`
# come from that enclosing scope.
nm = {}  # type: Dict[str, models.TorchSeqRecognizer]
ign_scripts = model.pop('ignore')
for k, v in model.items():
    # search order: path as given, the click app dir, the legacy model dir
    search = [v,
              os.path.join(click.get_app_dir(APP_NAME), v),
              os.path.join(LEGACY_MODEL_DIR, v)]
    location = next((loc for loc in search if os.path.isfile(loc)), None)
    if not location:
        raise click.BadParameter('No model {} for {} found'.format(v, k))
    message('Loading RNN {}\t'.format(k), nl=False)
    try:
        rnn = models.load_any(location, device=ctx.meta['device'])
        nm[k] = rnn
    except Exception:
        message('\u2717', fg='red')
        # Propagate the original error; a `ctx.exit(1)` placed after this
        # `raise` could never run, so it has been dropped.
        raise
    message('\u2713', fg='green')

if 'default' in nm:
    from collections import defaultdict

    # keys without an explicit model resolve to the 'default' model
    nn = defaultdict(lambda: nm['default'])  # type: Dict[str, models.TorchSeqRecognizer]
    nn.update(nm)
    nm = nn

# thread count is global so setting it once is sufficient.
# Use `nm` (always bound) instead of `nn` (only bound when a 'default' model
# exists) and do not depend on the leaked loop variable `k`, which is
# unbound when no models were given.
if nm:
    next(iter(nm.values())).nn.set_num_threads(threads)
# NOTE(review): fragment of a larger command function (its `def` is not
# visible and indentation was lost in this copy). `font`, `font_style`,
# `images`, `lines`, `prefill`, `text_direction`, `scale`, `maxcolseps`,
# `black_colseps`, `pad`, `log`, `logger`, `message` and `Image` come from the
# enclosing/module scope.
from kraken import pageseg
from kraken import transcribe
from kraken import binarization
from kraken.lib import models
from kraken.lib.util import is_bitonal
ti = transcribe.TranscriptionInterface(font, font_style)
# a --lines segmentation cannot be combined with more than one input image
if len(images) > 1 and lines:
raise click.UsageError('--lines option is incompatible with multiple image files')
# if given, replace the `prefill` path with the model loaded from it
# (presumably used to pre-fill transcriptions further down — not visible here)
if prefill:
logger.info('Loading model {}'.format(prefill))
message('Loading RNN', nl=False)
prefill = models.load_any(prefill)
message('\u2713', fg='green')
with log.progressbar(images, label='Reading images') as bar:
for fp in bar:
logger.info('Reading {}'.format(fp.name))
im = Image.open(fp)
# modes other than 1/L/P/RGB are converted to RGB before binarization
if im.mode not in ['1', 'L', 'P', 'RGB']:
logger.warning('Input {} is in {} color mode. Converting to RGB'.format(fp.name, im.mode))
im = im.convert('RGB')
logger.info('Binarizing page')
im_bin = binarization.nlbin(im)
im_bin = im_bin.convert('1')
logger.info('Segmenting page')
# segment automatically unless a line segmentation was supplied via `lines`
if not lines:
res = pageseg.segment(im_bin, text_direction, scale, maxcolseps, black_colseps, pad=pad)
else:
# NOTE(review): the body of this `else` branch is missing from this copy.
# Validate inputs, assemble the evaluation set and load every model in
# `model` for evaluation.
# NOTE(review): fragment of a command function whose `def` is not visible;
# `model`, `test_set`, `evaluation_files`, `threads`, `message` and `logger`
# come from that enclosing scope.
if not model:
    raise click.UsageError('No model to evaluate given.')

import numpy as np
from PIL import Image
from kraken.serialization import render_report
from kraken.lib import models
from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

# Merge evaluation_files into the test set and reject an empty set *before*
# loading any model, so missing data fails fast instead of after the
# (potentially slow) model loading.
test_set = list(test_set)
if evaluation_files:
    test_set.extend(evaluation_files)
if not test_set:
    raise click.UsageError('No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.')

# after the merge this equals len(test_set) + len(evaluation_files)
logger.info('Building test set from {} line images'.format(len(test_set)))

nn = {}
for p in model:
    message('Loading model {}\t'.format(p), nl=False)
    nn[p] = models.load_any(p)
    message('\u2713', fg='green')

# set number of OpenMP threads; only the first model is touched, presumably
# because the setting is process-wide — confirm
logger.debug('Set OpenMP threads to {}'.format(threads))
next(iter(nn.values())).nn.set_num_threads(threads)
def _get_text(im):
# NOTE(review): this copy is garbled — the docstring below has no opening
# triple quote, its text (script detection, ISO 15924 identifiers, clstm)
# does not match the name `_get_text`, and the body appears to continue past
# the end of this view. Confirm against the real source before editing.
Returns:
{'script_detection': True, 'text_direction': '$dir', 'boxes':
[[(script, (x1, y1, x2, y2)),...]]}: A dictionary containing the text
direction and a list of lists of reading order sorted bounding boxes
under the key 'boxes' with each list containing the script segmentation
of a single line. Script is a ISO15924 4 character identifier.
Raises:
KrakenInvalidModelException if no clstm module is available.
"""
# Every call raises here, so all statements below are currently unreachable
# (and the documented KrakenInvalidModelException is never raised).
raise NotImplementedError('Temporarily unavailable. Please open a github ticket if you want this fixed sooner.')
# NOTE(review): `model`, `bounds`, `valid_scripts`, `get_im_str`, `rpred`,
# `pkg_resources`, `json`, `models` and `logger` are not parameters of this
# function; they must come from a scope not visible here.
im_str = get_im_str(im)
logger.info(u'Detecting scripts with {} in {} lines on {}'.format(model, len(bounds['boxes']), im_str))
logger.debug(u'Loading detection model {}'.format(model))
rnn = models.load_any(model)
# load numerical to 4 char identifier map
logger.debug(u'Loading label to identifier map')
with pkg_resources.resource_stream(__name__, 'iso15924.json') as fp:
n2s = json.load(fp)
# convert allowed scripts to labels
# each allowed script's numeric key k becomes the code point 0xF0000 + k
# (Supplementary Private Use Area-A) and is collected in val_scripts
val_scripts = []
if valid_scripts:
logger.debug(u'Converting allowed scripts list {}'.format(valid_scripts))
for k, v in n2s.items():
if v in valid_scripts:
val_scripts.append(chr(int(k) + 0xF0000))
else:
# NOTE(review): this rebinds `valid_scripts`, not `val_scripts` (which is
# already []); possibly a typo — confirm intent.
valid_scripts = []
it = rpred(rnn, im, bounds, bidi_reordering=False)
preds = []
logger.debug(u'Running detection')
def __init__(self,
             model: vgsl.TorchVGSLModel,
             optimizer: torch.optim.Optimizer,
             device: str = 'cpu',
             filename_prefix: str = 'model',
             event_frequency: float = 1.0,
             train_set: torch.utils.data.DataLoader = None,
             val_set = None,
             stopper = None):
    """
    Set up the training state for ``model``.

    Args:
        model: network to train; also wrapped in a
            ``models.TorchSeqRecognizer(model, train=True, device=device)``
            stored as ``self.rec``.
        optimizer: optimizer, stored as ``self.optimizer``.
        device: device string, stored and passed to the recognizer wrapper.
        filename_prefix: stored as ``self.filename_prefix``.
        event_frequency: fraction of ``len(train_set)`` between events;
            ``self.event_it = int(len(train_set) * event_frequency)``.
        train_set: sized, re-iterable source of training batches. Despite
            the ``None`` default it is required.
        val_set: stored unchanged as ``self.val_set``.
        stopper: stopping criterion; a falsy value is replaced by
            ``NoStopping()``.

    Raises:
        TypeError: if ``train_set`` is ``None``.
    """
    if train_set is None:
        # previously surfaced as an opaque ``len(None)`` TypeError below
        raise TypeError('train_set is required')

    def _endless(iterable):
        # Like itertools.cycle, but starts a fresh pass over ``iterable`` each
        # time instead of replaying stored copies. Stops (rather than
        # spinning forever) if a pass yields nothing, matching cycle() on an
        # empty input.
        while True:
            pass_was_empty = True
            for item in iterable:
                pass_was_empty = False
                yield item
            if pass_was_empty:
                return

    self.model = model
    self.rec = models.TorchSeqRecognizer(model, train=True, device=device)
    self.optimizer = optimizer
    self.device = device
    self.filename_prefix = filename_prefix
    self.event_frequency = event_frequency
    self.event_it = int(len(train_set) * event_frequency)
    # itertools.cycle saves a copy of every element of the first pass and
    # replays those copies, which keeps a full epoch of batches in memory
    # and would prevent a shuffling DataLoader from ever reshuffling.
    self.train_set = _endless(train_set)
    self.val_set = val_set
    self.stopper = stopper if stopper else NoStopping()
    self.iterations = 0
    self.lr_scheduler = None