How to use the sockeye.data_io.SequenceReader class in sockeye

To help you get started, we’ve selected a few sockeye examples, based on popular ways it is used in public projects.
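
Before diving into the snippets, here is a minimal sketch of the pattern they all follow (the corpus file and sentences are illustrative, not taken from any project below): write a whitespace-tokenized file, build a vocabulary with vocab.build_vocab, and iterate over the integer-id sequences that data_io.SequenceReader yields.

from sockeye import data_io, vocab

sentences = ["a b c", "b c d e"]
with open("corpus.txt", "w") as f:              # illustrative toy corpus
    for line in sentences:
        print(line, file=f)

vocabulary = vocab.build_vocab(sentences)       # token -> integer id mapping
reader = data_io.SequenceReader("corpus.txt",
                                vocabulary=vocabulary,
                                add_eos=True)   # append the EOS id to every sequence

for ids in reader:
    print(ids)                                  # a list of token ids ending with the EOS id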

github awslabs / sockeye / test / unit / test_data_io.py View on Github
    with TemporaryDirectory() as work_dir:
        path = os.path.join(work_dir, 'input')
        with open(path, 'w') as f:
            for sequence in sequences:
                print(sequence, file=f)

        vocabulary = vocab.build_vocab(sequences) if use_vocab else None

        reader = data_io.SequenceReader(path, vocabulary=vocabulary, add_bos=add_bos, add_eos=add_eos)

        read_sequences = [s for s in reader]
        assert len(read_sequences) == len(sequences)

        if vocabulary is None:
            with pytest.raises(SockeyeError) as e:
                data_io.SequenceReader(path, vocabulary=vocabulary, add_bos=True)
            assert str(e.value) == "Adding a BOS or EOS symbol requires a vocabulary"

            expected_sequences = [data_io.strids2ids(get_tokens(s)) if s else None for s in sequences]
            assert read_sequences == expected_sequences
        else:
            expected_sequences = [data_io.tokens2ids(get_tokens(s), vocabulary) if s else None for s in sequences]
            if add_bos:
                expected_sequences = [[vocabulary[C.BOS_SYMBOL]] + s if s else None for s in expected_sequences]
            if add_eos:
                expected_sequences = [s + [vocabulary[C.EOS_SYMBOL]] if s else None for s in expected_sequences]
            assert read_sequences == expected_sequences
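
The test above exercises both modes of SequenceReader: without a vocabulary each line must already consist of integer ids (parsed with strids2ids) and requesting BOS or EOS raises a SockeyeError, while with a vocabulary tokens are mapped through tokens2ids and the BOS/EOS ids can be attached. A condensed sketch of the two calls (the paths and the vocabulary are placeholders):

# No vocabulary: lines must already be integer ids; add_bos/add_eos here would
# raise SockeyeError("Adding a BOS or EOS symbol requires a vocabulary").
id_reader = data_io.SequenceReader("train.ids", vocabulary=None)

# With a vocabulary: tokens are looked up, and the BOS/EOS ids can be added.
token_reader = data_io.SequenceReader("train.tok",
                                      vocabulary=vocabulary,
                                      add_bos=True,
                                      add_eos=True)
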
github Cartus / DCGCN / sockeye / data_io.py View on Github
def create_graph_readers(sources: List[str], target: str, source_graph: str,
                         vocab_sources: List[vocab.Vocab],
                         vocab_target: vocab.Vocab,
                         vocab_edges: vocab.Vocab) -> Tuple[List[SequenceReader], SequenceReader, GraphReader]:
    """
    Create source readers with EOS, a target reader with BOS, and a reader for the source graphs.

    :param sources: The file names of source data and factors.
    :param target: The file name of the target data.
    :param source_graph: The file name of the source graph data.
    :param vocab_sources: The source vocabularies.
    :param vocab_target: The target vocabulary.
    :param vocab_edges: The edge label vocabulary.
    :return: The source sequence readers, the target sequence reader and the graph reader.
    """
    source_sequence_readers = [SequenceReader(source, vocab, add_eos=True) for source, vocab in
                               zip(sources, vocab_sources)]
    target_sequence_reader = SequenceReader(target, vocab_target, add_bos=True)
    graph_reader = GraphReader(source_graph, vocab_edges)
    return source_sequence_readers, target_sequence_reader, graph_reader
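
This DCGCN variant follows the same pattern as the upstream create_sequence_readers shown next, adding a GraphReader over the source graph file with its own edge vocabulary (vocab_edges).
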
github awslabs / sockeye / sockeye / data_io.py View on Github
def create_sequence_readers(sources: List[str], target: str,
                            vocab_sources: List[vocab.Vocab],
                            vocab_target: vocab.Vocab) -> Tuple[List[SequenceReader], SequenceReader]:
    """
    Create source readers with EOS and target readers with BOS.

    :param sources: The file names of source data and factors.
    :param target: The file name of the target data.
    :param vocab_sources: The source vocabularies.
    :param vocab_target: The target vocabulary.
    :return: The source sequence readers and the target reader.
    """
    source_sequence_readers = [SequenceReader(source, vocab, add_eos=True) for source, vocab in
                               zip(sources, vocab_sources)]
    target_sequence_reader = SequenceReader(target, vocab_target, add_bos=True)
    return source_sequence_readers, target_sequence_reader
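
create_sequence_readers pairs each source (and source-factor) file with its vocabulary and appends EOS to the source sequences, while the target sequences get BOS prepended. A hedged usage sketch (file names and vocabularies are placeholders):

source_readers, target_reader = create_sequence_readers(
    sources=["train.src", "train.src.factor1"],   # source data plus optional factor files
    target="train.trg",
    vocab_sources=[src_vocab, factor_vocab],      # one vocabulary per source file
    vocab_target=trg_vocab)

Each returned reader is an iterable SequenceReader, ready to be consumed by the downstream dataset loaders.
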
github awslabs / sockeye / sockeye / image_captioning / data_io.py View on Github
  Returns a data iterator for scoring. The iterator loads data on demand,
  batch by batch, and does not skip any lines. Lines that are too long
  are truncated.
  # TODO

  """
  logger.info("==============================")
  logger.info("Creating scoring data iterator")
  logger.info("==============================")

  # One bucket to hold them all,
  bucket = (max_seq_len_source, max_seq_len_target)
  buckets = [bucket]

  source_images = [FileListReader(sources[0], source_root)]
  target_sentences = SequenceReader(target, vocab_target, add_bos=True)

  # ...One loader to raise them,
  data_loader = RawListTextDatasetLoader(buckets=buckets,
                                         eos_id=vocab_target[C.EOS_SYMBOL],
                                         pad_id=C.PAD_ID)

  data_statistics = get_data_statistics(source_readers=None,
                                        target_reader=target_sentences,
                                        buckets=buckets,
                                        length_ratio_mean=1.0,
                                        length_ratio_std=1.0,
                                        source_vocabs=None,
                                        target_vocab=vocab_target)

  bucket_batch_sizes = define_bucket_batch_sizes(buckets,
                                                 batch_size,
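
In this image-captioning scoring iterator, only the target side goes through SequenceReader (again with add_bos=True); the source side is a list of image paths handled by FileListReader.
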
github awslabs / sockeye / sockeye / data_io.py View on Github
                                         target_vocab=target_vocab,
                                         num_shards=num_shards,
                                         buckets=buckets,
                                         length_ratio_mean=length_statistics.length_ratio_mean,
                                         length_ratio_std=length_statistics.length_ratio_std,
                                         output_prefix=output_prefix)
    data_statistics.log()

    data_loader = RawParallelDatasetLoader(buckets=buckets,
                                           eos_id=target_vocab[C.EOS_SYMBOL],
                                           pad_id=C.PAD_ID)

    # 3. convert each shard to serialized ndarrays
    for shard_idx, (shard_sources, shard_target, shard_stats) in enumerate(shards):
        sources_sentences = [SequenceReader(s) for s in shard_sources]
        target_sentences = SequenceReader(shard_target)
        dataset = data_loader.load(sources_sentences, target_sentences, shard_stats.num_sents_per_bucket)
        shard_fname = os.path.join(output_prefix, C.SHARD_NAME % shard_idx)
        shard_stats.log()
        logger.info("Writing '%s'", shard_fname)
        dataset.save(shard_fname)

        if not keep_tmp_shard_files:
            for f in shard_sources:
                os.remove(f)
            os.remove(shard_target)

    data_info = DataInfo(sources=[os.path.abspath(fname) for fname in source_fnames],
                         target=os.path.abspath(target_fname),
                         source_vocabs=source_vocab_paths,
                         target_vocab=target_vocab_path,
                         shared_vocab=shared_vocab,
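
Note that the SequenceReader calls in this last snippet pass no vocabulary; as the unit test at the top demonstrates, that mode parses each line of the shard files as whitespace-separated integer ids via strids2ids rather than looking tokens up in a vocabulary.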