How to use the torchtext.data.Dataset class in torchtext

To help you get started, we've selected a few examples of torchtext.data.Dataset in use, drawn from popular public projects.

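Before the project snippets, here is a minimal, self-contained sketch of the usual workflow with the classic torchtext.data API. All field and column names are illustrative and not taken from any of the projects below; note that in newer torchtext releases (0.9+) this classic API was moved to torchtext.legacy.data and later removed.

from torchtext import data

# Define how each column of the raw data is tokenized and numericalized.
TEXT = data.Field(sequential=True, tokenize=str.split, lower=True)
LABEL = data.Field(sequential=False)
fields = [("text", TEXT), ("label", LABEL)]

# Wrap raw rows in Example objects, then hand them to Dataset.
rows = [["a short example sentence", "pos"],
        ["another example", "neg"]]
examples = [data.Example.fromlist(row, fields) for row in rows]
dataset = data.Dataset(examples, fields)

# Vocabularies are built per Field, from one or more Datasets.
TEXT.build_vocab(dataset)
LABEL.build_vocab(dataset)

# BucketIterator batches examples of similar length to minimize padding.
iterator = data.BucketIterator(dataset, batch_size=2, repeat=False,
                               sort_key=lambda ex: len(ex.text))
for batch in iterator:
    print(batch.text, batch.label)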

From pytorch/text on GitHub: test/data/test_field.py
def test_build_vocab_from_dataset(self):
        nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                                   init_token="<w>", eos_token="</w>")
        CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
        ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
        ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
        dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])

        CHARS.build_vocab(dataset, min_freq=2)

        expected = "a b   <s> </s>  ".split()
        assert len(CHARS.vocab) == len(expected)
        for c in expected:
            assert c in CHARS.vocab.stoi

        expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
        assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs
From AnubhavGupta3377/Text-Classification-Models-Pytorch on GitHub: Model_fastText/utils.py
datafields = [("text",TEXT),("label",LABEL)]
        
        # Load data from pd.DataFrame into torchtext.data.Dataset
        train_df = self.get_pandas_df(train_file)
        train_examples = [data.Example.fromlist(i, datafields) for i in train_df.values.tolist()]
        train_data = data.Dataset(train_examples, datafields)
        
        test_df = self.get_pandas_df(test_file)
        test_examples = [data.Example.fromlist(i, datafields) for i in test_df.values.tolist()]
        test_data = data.Dataset(test_examples, datafields)
        
        # If validation file exists, load it. Otherwise get validation data from training data
        if val_file:
            val_df = self.get_pandas_df(val_file)
            val_examples = [data.Example.fromlist(i, datafields) for i in val_df.values.tolist()]
            val_data = data.Dataset(val_examples, datafields)
        else:
            train_data, val_data = train_data.split(split_ratio=0.8)
        
        TEXT.build_vocab(train_data, vectors=Vectors(w2v_file))
        self.word_embeddings = TEXT.vocab.vectors
        self.vocab = TEXT.vocab
        
        self.train_iterator = data.BucketIterator(
            (train_data),
            batch_size=self.config.batch_size,
            sort_key=lambda x: len(x.text),
            repeat=False,
            shuffle=True)
        
        self.val_iterator, self.test_iterator = data.BucketIterator.splits(
            (val_data, test_data),
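The call above is cut off by the snippet. For reference, BucketIterator.splits returns one iterator per dataset passed in; a minimal sketch of a complete call, with the keyword arguments assumed for illustration rather than copied from the repository:

        # Assumed completion: one iterator is returned per dataset, in order.
        self.val_iterator, self.test_iterator = data.BucketIterator.splits(
            (val_data, test_data),
            batch_size=self.config.batch_size,
            sort_key=lambda x: len(x.text),
            repeat=False,
            shuffle=False)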
From gaojun4ever/JunNMT on GitHub: nmt/IO.py
def merge_vocabs(vocabs, specials, vocab_size=None):
    """
    Merge individual vocabularies (assumed to be generated from disjoint
    documents) into a larger vocabulary.
    Args:
        vocabs: `torchtext.vocab.Vocab` vocabularies to be merged
        specials: `list` of special tokens to include in the merged vocabulary
        vocab_size: `int` the final vocabulary size. `None` for no limit.
    Returns:
        `torchtext.vocab.Vocab`
    """
    merged = sum([vocab.freqs for vocab in vocabs], Counter())
    return torchtext.vocab.Vocab(merged,
                                 specials=specials,
                                 max_size=vocab_size)

class NMTDataset(torchtext.data.Dataset):

    # @staticmethod
    # def sort_key(ex):
    #     return data.interleave_keys(len(ex.src), len(ex.tgt))

    def sort_key(self, ex):
        """ Sort using length of source sentences. """
        # Default to a balanced sort, prioritizing tgt len match.
        # TODO: make this configurable.
        if hasattr(ex, "tgt"):
            return -len(ex.src), -len(ex.tgt)
        return -len(ex.src)

    def __init__(self, src_path, tgt_path, fields, **kwargs):

        make_example = torchtext.data.Example.fromlist
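A short usage sketch for the merge_vocabs helper shown above; the field objects and the special-token list here are assumptions for illustration, not code from JunNMT:

# Build a single vocabulary shared by source and target fields
# (SRC and TGT are hypothetical torchtext Fields whose vocabs were already built).
shared_vocab = merge_vocabs(
    [SRC.vocab, TGT.vocab],
    specials=['<unk>', '<pad>', '<s>', '</s>'],
    vocab_size=50000)
SRC.vocab = shared_vocab
TGT.vocab = shared_vocab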
From arthurmensch/didyprog on GitHub: didyprog/ner/externals/torchtext/data.py
        for i, arr in enumerate(arrs):
            if self.nesting_field.include_lengths:
                arr = tuple(arr)
            numericalized_ex = self.nesting_field.numericalize(
                arr, device=device, train=train)
            if self.nesting_field.include_lengths:
                numericalized_ex, lengths_ex = numericalized_ex
                numericalized[i, :, :] = numericalized_ex.data
                lengths[i, :len(lengths_ex)] = lengths_ex
        if self.nesting_field.include_lengths:
            return Variable(numericalized), lengths
        else:
            return Variable(numericalized)


class SequenceTaggingDataset(data.Dataset):
    """Defines a dataset for sequence tagging. Examples in this dataset
    contain paired lists -- paired list of words and tags.

    For example, in the case of part-of-speech tagging, an example is of the
    form
    [I, love, PyTorch, .] paired with [PRON, VERB, PROPN, PUNCT]

    See torchtext/test/sequence_tagging.py on how to use this class.
    """

    @staticmethod
    def sort_key(example):
        for attr in dir(example):
            if not callable(getattr(example, attr)) and \
                    not attr.startswith("__"):
                return len(getattr(example, attr))
From MillionIntegrals/vel on GitHub: vel/data/source/nlp/multi30k.py
            src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)

            examples = []

            with io.open(src_path, mode='r', encoding='utf-8') as src_file, \
                    io.open(trg_path, mode='r', encoding='utf-8') as trg_file:
                for src_line, trg_line in zip(src_file, trg_file):
                    src_line, trg_line = src_line.strip(), trg_line.strip()
                    if src_line != '' and trg_line != '':
                        examples.append(data.Example.fromlist(
                            [src_line, trg_line], fields))

            with open(cache_file, 'wb') as fp:
                pickle.dump(examples, file=fp)

        data.Dataset.__init__(self, examples, fields, **kwargs)
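The snippet builds Example objects from the raw parallel files and pickles them to cache_file before calling data.Dataset.__init__. A minimal sketch of the matching cache-hit branch, which the snippet does not show:

            # Assumed cache-hit path: reuse the previously pickled Example list.
            with open(cache_file, 'rb') as fp:
                examples = pickle.load(fp)

Either way, the resulting list of Examples is passed to data.Dataset.__init__ together with the fields.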
From jadore801120/attention-is-all-you-need-pytorch on GitHub: train.py
    opt.max_token_seq_len = data['settings'].max_len
    opt.src_pad_idx = data['vocab']['src'].vocab.stoi[Constants.PAD_WORD]
    opt.trg_pad_idx = data['vocab']['trg'].vocab.stoi[Constants.PAD_WORD]

    opt.src_vocab_size = len(data['vocab']['src'].vocab)
    opt.trg_vocab_size = len(data['vocab']['trg'].vocab)

    #========= Preparing Model =========#
    if opt.embs_share_weight:
        assert data['vocab']['src'].vocab.stoi == data['vocab']['trg'].vocab.stoi, \
            'To sharing word embedding the src/trg word2idx table shall be the same.'

    fields = {'src': data['vocab']['src'], 'trg':data['vocab']['trg']}

    train = Dataset(examples=data['train'], fields=fields)
    val = Dataset(examples=data['valid'], fields=fields)

    train_iterator = BucketIterator(train, batch_size=batch_size, device=device, train=True)
    val_iterator = BucketIterator(val, batch_size=batch_size, device=device)

    return train_iterator, val_iterator
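A brief sketch of consuming the iterators returned above; the batch attribute names follow the 'src' and 'trg' keys of the fields dict, and the training-loop body is omitted:

    for batch in train_iterator:
        src_seq = batch.src   # padded tensor of source token indices
        trg_seq = batch.trg   # padded tensor of target token indices
        # ... forward pass, loss, optimizer step ...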
From matthew-z/R-net on GitHub: src/preprocess_squad.py
def get_dataset(json_path, cache_root="./data/cache", mode="train", tokenizer=None):
    examples, fields = build_examples_from_json(json_path, mode, tokenizer)

    # examples_file_name="%s%s.examples" % (mode, "_DEBUG" if DEBUG else "")
    # example_file = os.path.join(cache_root, examples_file_name)
    #
    # if os.path.exists(example_file):
    #     print("loading examples from %s" % example_file)
    #     examples = pickle.load(open(example_file, "rb"))
    # else:
    #     print("building examples %s" % example_file)
    #     examples = extract_method(fields, json_path)
    #     pickle.dump(examples, open(example_file, "wb"))

    squad_dataset = data.Dataset(examples, fields)

    # build voc
    squad_dataset.fields["passage"].build_vocab(squad_dataset, [x.question for x in squad_dataset.examples],
                                                wv_type='glove.840B', wv_dir="./data/embedding/glove_word/",
                                                unk_init="zero")

    squad_dataset.fields["question"].vocab = squad_dataset.fields["passage"].vocab

    #
    # squad_dataset.fields["question"].build_vocab(squad_dataset, [x.passage for x in squad_dataset.examples],
    #                                              wv_type='glove.840B', wv_dir="./data/embedding/glove_word/",
    #                                              unk_init="zero")

    dataset_file_name = "%s%s.dataset" % (mode, "_DEBUG" if DEBUG else "")

    # squad_dataset.fields["answer_text"].build_vocab(squad_dataset,[x.passage for x in squad_dataset.examples], wv_type='glove.840B',
From pytorch/text on GitHub: torchtext/datasets/imdb.py
import os
import glob
import io

from .. import data


class IMDB(data.Dataset):

    urls = ['http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz']
    name = 'imdb'
    dirname = 'aclImdb'

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    def __init__(self, path, text_field, label_field, **kwargs):
        """Create an IMDB dataset instance given a path and fields.

        Arguments:
            path: Path to the dataset's highest level directory
            text_field: The field that will be used for text data.
            label_field: The field that will be used for label data.
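In practice a Dataset subclass like IMDB is usually instantiated through its splits classmethod, which downloads and extracts the archive listed in urls on first use. A minimal sketch, with field settings chosen for illustration:

from torchtext import data
from torchtext.datasets import IMDB

TEXT = data.Field(lower=True, tokenize=str.split)
LABEL = data.LabelField()

# Returns one Dataset per split; downloads aclImdb_v1.tar.gz if not already cached.
train_data, test_data = IMDB.splits(TEXT, LABEL)

TEXT.build_vocab(train_data, max_size=25000)
LABEL.build_vocab(train_data)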
From tacchinotacchi/distil-bilstm on GitHub: generate_dataset.py
    if not args.no_augment:
        sentences = augmentation(sentences, pos_dict)
    else:
        sentences = [text for text, _ in input_tsv]

    # Load teacher model
    model = BertForSequenceClassification.from_pretrained(args.model).to(device)
    tokenizer = BertTokenizer.from_pretrained(args.model, do_lower_case=True)

    # Assign labels with teacher
    teacher_field = data.Field(sequential=True, tokenize=tokenizer.tokenize, lower=True, include_lengths=True, batch_first=True)
    fields = [("text", teacher_field)]
    if not args.no_augment:
        examples = [data.Example.fromlist([" ".join(words)], fields) for words in sentences]
    else:
        examples = [data.Example.fromlist([text], fields) for text in sentences]
    augmented_dataset = data.Dataset(examples, fields)
    teacher_field.vocab = BertVocab(tokenizer.vocab)
    new_labels = BertTrainer(model, device, batch_size=args.batch_size).infer(augmented_dataset)

    # Write to file
    with open(args.output, "w") as f:
        f.write("sentence\tscores\n")
        for sentence, rating in zip(sentences, new_labels):
            if not args.no_augment:
                text = " ".join(sentence)
            else: text = sentence
            f.write("%s\t%.6f %.6f\n" % (text, *rating))
From pytorch/text on GitHub: torchtext/datasets/sequence_tagging.py
from .. import data
import random


class SequenceTaggingDataset(data.Dataset):
    """Defines a dataset for sequence tagging. Examples in this dataset
    contain paired lists -- paired list of words and tags.

    For example, in the case of part-of-speech tagging, an example is of the
    form
    [I, love, PyTorch, .] paired with [PRON, VERB, PROPN, PUNCT]

    See torchtext/test/sequence_tagging.py on how to use this class.
    """

    @staticmethod
    def sort_key(example):
        for attr in dir(example):
            if not callable(getattr(example, attr)) and \
                    not attr.startswith("__"):
                return len(getattr(example, attr))