How to use sockeye - 10 common examples

To help you get started, we’ve selected ten sockeye examples, drawn from popular ways the library is used in public projects.

Example 1: github awslabs / sockeye / sockeye / encoder.py (View on GitHub)
def __init__(self,
                 rnn_config: rnn.RNNConfig,
                 prefix=C.BIDIRECTIONALRNN_PREFIX,
                 layout=C.TIME_MAJOR,
                 encoder_class: Callable = RecurrentEncoder) -> None:
        utils.check_condition(rnn_config.num_hidden % 2 == 0,
                              "num_hidden must be a multiple of 2 for BiDirectionalRNNEncoders.")
        super().__init__(rnn_config.dtype)
        self.rnn_config = rnn_config
        self.internal_rnn_config = rnn_config.copy(num_hidden=rnn_config.num_hidden // 2)
        if layout[0] == 'N':
            logger.warning("Batch-major layout for encoder input. Consider using time-major layout for faster speed")

        # time-major layout as _encode needs to swap layout for SequenceReverse
        self.forward_rnn = encoder_class(rnn_config=self.internal_rnn_config,
                                         prefix=prefix + C.FORWARD_PREFIX,
                                         layout=C.TIME_MAJOR)
        self.reverse_rnn = encoder_class(rnn_config=self.internal_rnn_config,
                                         prefix=prefix + C.REVERSE_PREFIX,
                                         layout=C.TIME_MAJOR)
        self.layout = layout
        self.prefix = prefix
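
For context, a minimal construction sketch follows; the RNNConfig argument names and the C.LSTM_TYPE constant are assumptions based on Sockeye 1.x and may differ in other versions.

from sockeye import constants as C
from sockeye import encoder, rnn

# Hypothetical configuration; num_hidden must be even because each
# direction of the bidirectional encoder receives num_hidden // 2.
rnn_config = rnn.RNNConfig(cell_type=C.LSTM_TYPE,
                           num_hidden=512,
                           num_layers=1,
                           dropout_inputs=0.0,
                           dropout_states=0.0)
birnn = encoder.BiDirectionalRNNEncoder(rnn_config=rnn_config)
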
Example 2: github awslabs / sockeye / test / integration / test_seq_copy_int.py (View on GitHub)
def _test_parameter_averaging(model_path: str):
    """
    Runs parameter averaging with all available strategies.
    """
    for strategy in C.AVERAGE_CHOICES:
        points = sockeye.average.find_checkpoints(model_path=model_path,
                                                  size=4,
                                                  strategy=strategy,
                                                  metric=C.PERPLEXITY)
        assert len(points) > 0
        averaged_params = sockeye.average.average(points)
        assert averaged_params
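
Outside the test harness, the same two calls drive checkpoint averaging end to end. In the sketch below, the literal "best"/"perplexity" strings and the utils.save_params save step are assumptions based on Sockeye 1.x.

import sockeye.average
import sockeye.utils

# Pick the 4 checkpoints with the best validation perplexity and average them.
points = sockeye.average.find_checkpoints(model_path="my_model",
                                          size=4,
                                          strategy="best",
                                          metric="perplexity")
averaged_params = sockeye.average.average(points)
sockeye.utils.save_params(averaged_params, "my_model/params.average")
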
Example 3: github awslabs / sockeye / test / unit / test_utils.py (View on GitHub)
def test_check_condition_true():
    utils.check_condition(1 == 1, "Nice")
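
The companion negative case is just as short: check_condition raises utils.SockeyeError with the supplied message when the predicate is false. Written against the same imports as the test above:

def test_check_condition_false():
    with pytest.raises(utils.SockeyeError) as e:
        utils.check_condition(1 == 2, "Not nice")
    assert str(e.value) == "Not nice"
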
Example 4: github awslabs / sockeye / test / unit / test_utils.py (View on GitHub)
def _get_later_major_version():
    release, major, minor = utils.parse_version(__version__)
    return "%s.%d.%s" % (release, int(major) + 1, minor)
Example 5: github awslabs / sockeye / test / unit / test_utils.py (View on GitHub)
def test_expand_requested_device_ids_exception(requested_device_ids, num_gpus_available):
    with pytest.raises(ValueError):
        utils._expand_requested_device_ids(requested_device_ids, num_gpus_available)
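
The pytest.mark.parametrize decorator that feeds this test is cut off in the snippet. Restored with hypothetical failing cases (the reading of negative IDs as "any n GPUs" is an assumption about Sockeye's device handling), the full test would look like:

@pytest.mark.parametrize("requested_device_ids, num_gpus_available", [
    ([0, 1, 2], 1),   # more explicit device IDs than GPUs available
    ([2], 1),         # device ID out of range
    ([-3], 2),        # requests any 3 GPUs, but only 2 exist
])
def test_expand_requested_device_ids_exception(requested_device_ids, num_gpus_available):
    with pytest.raises(ValueError):
        utils._expand_requested_device_ids(requested_device_ids, num_gpus_available)
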
Example 6: github awslabs / sockeye / test / system / test_seq_copy_sys.py (View on GitHub)
# NOTE: the def line below is reconstructed from the parameters the body uses;
# the original snippet was truncated mid-signature.
def test_seq_sort(name, train_params, translate_params, use_prepared_data, seed,
                  use_source_factor, perplexity_thresh, bleu_thresh):
    """Task: sort short sequences of digits"""
    with tmp_digits_dataset("test_seq_sort.", _TRAIN_LINE_COUNT, _TRAIN_LINE_COUNT_EMPTY, _LINE_MAX_LENGTH, _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
                            _TEST_LINE_COUNT, _TEST_LINE_COUNT_EMPTY, _TEST_MAX_LENGTH,
                            sort_target=True, seed_train=_SEED_TRAIN_DATA, seed_dev=_SEED_DEV_DATA,
                            with_source_factors=use_source_factor) as data:
        data = check_train_translate(train_params=train_params,
                                     translate_params=translate_params,
                                     data=data,
                                     use_prepared_data=use_prepared_data,
                                     max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
                                     compare_output=True,
                                     seed=seed)

        # get best validation perplexity
        metrics = sockeye.utils.read_metrics_file(os.path.join(data['model'], C.METRICS_NAME))
        perplexity = min(m[C.PERPLEXITY + '-val'] for m in metrics)

        # compute metrics
        bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=data['test_outputs'], references=data['test_targets'])
        chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=data['test_outputs'], references=data['test_targets'])
        bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=data['test_outputs_restricted'],
                                                         references=data['test_targets'])

        logger.info("test: %s", name)
        logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)
        assert perplexity <= perplexity_thresh
        assert bleu >= bleu_thresh
        assert bleu_restrict >= bleu_thresh
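
The evaluation helpers used at the end also work standalone on plain lists of sentence strings, using the same keyword arguments as above:

import sockeye.evaluate

hypotheses = ["1 2 3 4", "5 6 7"]
references = ["1 2 3 4", "5 6 8"]
bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses, references=references)
chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=hypotheses, references=references)
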
Example 7: github awslabs / sockeye / test / unit / test_data_io.py (View on GitHub)
# NOTE: the opening of this test is reconstructed; the def line and the two
# count values below are assumptions, as the original snippet was truncated.
def test_get_training_data_iters():
    train_line_count = 100
    train_line_count_empty = 0
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    expected_mean = 1.0
    expected_std = 0.0
    test_line_count = 20
    test_line_count_empty = 0
    test_max_length = 30
    batch_size = 5
    with tmp_digits_dataset("tmp_corpus",
                            train_line_count, train_line_count_empty, train_max_length - C.SPACE_FOR_XOS,
                            dev_line_count, dev_max_length - C.SPACE_FOR_XOS,
                            test_line_count, test_line_count_empty,
                            test_max_length - C.SPACE_FOR_XOS) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths([data['train_source'], data['train_target']])

        train_iter, val_iter, config_data, data_info = data_io.get_training_data_iters(
            sources=[data['train_source']],
            target=data['train_target'],
            validation_sources=[data['dev_source']],
            validation_target=data['dev_target'],
            source_vocabs=[vcb],
            target_vocab=vcb,
            source_vocab_paths=[None],
            target_vocab_path=None,
            shared_vocab=True,
            batch_size=batch_size,
            batch_by_words=False,
            batch_num_devices=1,
            max_seq_len_source=train_max_length,
            max_seq_len_target=train_max_length,
            bucketing=True,
            bucket_width=10)  # closing arguments are assumptions; the snippet was cut off here
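
Once constructed, train_iter can be consumed like any MXNet data iterator (an assumption consistent with Sockeye 1.x):

# Walk one epoch of training batches, then rewind for the next pass.
for batch in train_iter:
    pass  # hand each batch to the training loop here
train_iter.reset()
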
Example 8: github awslabs / sockeye / test / unit / test_config.py (View on GitHub)

import tempfile
import os

import pytest

from sockeye import config


class ConfigTest(config.Config):
    yaml_tag = "!ConfigTest"

    def __init__(self, param, config=None):
        super().__init__()
        self.param = param
        self.config = config


def test_base_freeze():
    c = config.Config()
    c.param = 1
    assert c.param == 1
    c.freeze()
    with pytest.raises(AttributeError) as e:
        c.param = 2
    assert str(e.value) == "Cannot set 'param' in frozen config"
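
The tempfile and os imports above serve round-trip tests elsewhere in the file. A sketch of such a test, assuming the save/load methods of Sockeye's config module:

def test_config_save_load():
    c = ConfigTest(param=1, config=ConfigTest(param=2))
    with tempfile.TemporaryDirectory() as work_dir:
        fname = os.path.join(work_dir, "config.yaml")
        c.save(fname)                      # serialize to YAML via yaml_tag
        loaded = config.Config.load(fname)
        assert loaded.param == 1
        assert loaded.config.param == 2
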
Example 9: github awslabs / sockeye / test / unit / test_checkpoint.py (View on GitHub)
def create_parallel_sentence_iter(source_sentences, target_sentences, max_len, batch_size, batch_by_words):
    buckets = sockeye.data_io.define_parallel_buckets(max_len, max_len, 10)
    batch_num_devices = 1
    eos = 0
    pad = 1
    unk = 2
    bucket_iterator = sockeye.data_io.ParallelBucketSentenceIter(source_sentences,
                                                                 target_sentences,
                                                                 buckets,
                                                                 batch_size,
                                                                 batch_by_words,
                                                                 batch_num_devices,
                                                                 eos, pad, unk)
    return bucket_iterator
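
A call with hypothetical toy data shows the expected input shape: integer token-ID sequences, one list per sentence (the next() call assumes MXNet's DataIter interface):

source_sentences = [[3, 4, 5], [3, 4, 5, 6, 7]]   # token IDs per source sentence
target_sentences = [[3, 4, 5, 6], [3, 4]]
it = create_parallel_sentence_iter(source_sentences, target_sentences,
                                   max_len=10, batch_size=2, batch_by_words=False)
batch = it.next()   # first padded, bucketed mini-batch
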
Example 10: github awslabs / sockeye / test / unit / test_arguments.py (View on GitHub)
def test_io_args(test_params, expected_params):
    _test_args(test_params, expected_params, arguments.add_training_io_args)
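
The snippet omits the _test_args helper and the test's parametrization. A plausible reconstruction of the helper, which builds a parser with the given args_func, parses the command line, and compares the result against the expected dict:

import argparse

def _test_args(test_params, expected_params, args_func):
    test_parser = argparse.ArgumentParser()
    args_func(test_parser)   # e.g. arguments.add_training_io_args
    parsed = vars(test_parser.parse_args(test_params.split()))
    assert parsed == expected_params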