How to use the kraken.lib.vgsl.TorchVGSLModel class in kraken

To help you get started, we’ve selected a few kraken examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github mittagessen / kraken / kraken / ketos.py View on Github external
elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label()+1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label()+1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, num_workers=loader_threads, pin_memory=True)
    threads = max(threads-loader_threads, 1)

    # don't encode validation set as the alphabets may not match causing encoding failures
github mittagessen / kraken / tests / test_vgsl.py View on Github external
def test_helper_train(self):
        """
        Checks that the train()/eval() helpers switch both the global
        autograd state and the `training` flag of the model's `nn` attribute.
        """
        model = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')

        # after train(): gradients enabled, network flagged as training
        model.train()
        self.assertTrue(torch.is_grad_enabled())
        self.assertTrue(model.nn.training)

        # after eval(): both are expected to be switched off
        model.eval()
        self.assertFalse(torch.is_grad_enabled())
        self.assertFalse(model.nn.training)
github mittagessen / kraken / tests / test_vgsl.py View on Github external
def test_save_model(self):
        """
        Test model serialization.

        Builds a model from a spec string, saves it into a temporary
        directory and asserts that the target file exists afterwards.
        """
        rnn = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
        # named `tmp_dir` so the builtin `dir` is not shadowed
        with tempfile.TemporaryDirectory() as tmp_dir:
            # build the path once and portably instead of concatenating
            # '/' by hand in two places
            model_path = os.path.join(tmp_dir, 'foo.mlmodel')
            rnn.save_model(model_path)
            self.assertTrue(os.path.exists(model_path))
github mittagessen / kraken / tests / test_vgsl.py View on Github external
def test_del_resize(self):
        """
        Resizes the output layer while passing a list of labels to delete
        and verifies the resulting number of output features.
        """
        model = vgsl.TorchVGSLModel('[1,1,0,48 Lbx10 Do O1c57]')
        deleted_labels = [2, 4, 5, 6, 7, 12, 25]
        model.resize_output(80, deleted_labels)
        # the last entry of the network is the output layer under test
        final_layer = model.nn[-1]
        self.assertEqual(final_layer.lin.out_features, 80)
github mittagessen / kraken / kraken / lib / models.py View on Github external
Returns:
        A kraken.lib.models.TorchSeqRecognizer object.
    """
    nn = None
    kind = ''
    fname = abspath(expandvars(expanduser(fname)))
    logger.info(u'Loading model from {}'.format(fname))
    try:
        nn = TorchVGSLModel.load_model(str(fname))
        kind = 'vgsl'
    except Exception:
        try:
            nn = TorchVGSLModel.load_clstm_model(fname)
            kind = 'clstm'
        except Exception:
            nn = TorchVGSLModel.load_pronn_model(fname)
            kind = 'pronn'
        try:
            nn = TorchVGSLModel.load_pyrnn_model(fname)
            kind = 'pyrnn'
        except Exception:
            pass
    if not nn:
        raise KrakenInvalidModelException('File {} not loadable by any parser.'.format(fname))
    seq = TorchSeqRecognizer(nn, train=train, device=device)
    seq.kind = kind
    return seq