How to use the torchtext.data.Field class in torchtext

To help you get started, we’ve selected a few torchtext.data.Field examples based on popular ways it is used in public projects.
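
Before the project snippets below, here is a minimal, self-contained sketch of the usual Field workflow: define fields, wrap raw examples in a Dataset, build vocabularies, then batch with an iterator. It assumes the classic torchtext.data API (moved to torchtext.legacy.data in torchtext 0.9+); the field names, sample sentences, and batch size are invented for illustration.

from torchtext.data import Field, LabelField, Example, Dataset, BucketIterator

# A Field describes how one column of raw data becomes a tensor.
TEXT = Field(sequential=True, tokenize=lambda s: s.split(), lower=True, batch_first=True)
LABEL = LabelField()
fields = [("text", TEXT), ("label", LABEL)]

# Toy examples; in real projects these usually come from TabularDataset or a custom Dataset.
raw = [("the movie was great", "pos"), ("the movie was terrible", "neg")]
examples = [Example.fromlist(list(row), fields) for row in raw]
dataset = Dataset(examples, fields)

# Vocabularies must be built before iterating, otherwise numericalization has no mapping.
TEXT.build_vocab(dataset)
LABEL.build_vocab(dataset)

# BucketIterator groups examples of similar length to minimize padding.
iterator = BucketIterator(dataset, batch_size=2, device="cpu", repeat=False,
                          sort_key=lambda ex: len(ex.text), sort=True)
for batch in iterator:
    print(batch.text.shape, batch.label)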


github pytorch / text / test / data / test_dataset.py
def test_csv_dataset_quotechar(self):
        # Based on issue #349
        example_data = [("text", "label"),
                        ('" hello world', "0"),
                        ('goodbye " world', "1"),
                        ('this is a pen " ', "0")]

        with tempfile.NamedTemporaryFile(dir=self.test_dir) as f:
            for example in example_data:
                f.write(six.b("{}\n".format(",".join(example))))

            TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
            fields = {
                "label": ("label", data.Field(use_vocab=False,
                                              sequential=False)),
                "text": ("text", TEXT)
            }

            f.seek(0)

            dataset = data.TabularDataset(
                path=f.name, format="csv",
                skip_header=False, fields=fields,
                csv_reader_params={"quotechar": None})

            TEXT.build_vocab(dataset)

            self.assertEqual(len(dataset), len(example_data) - 1)

            for i, example in enumerate(dataset):
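
In the test above, csv_reader_params is forwarded straight to Python's csv.reader, so any dialect option (quotechar, delimiter, escapechar, and so on) can be tuned; passing quotechar=None effectively disables quote handling, so stray quotes stay in the text. A rough, self-contained variant of the same idea, writing an invented temporary CSV file:

import os
import tempfile
from torchtext import data

rows = ['" hello world,0', 'goodbye " world,1']   # text column contains literal quote characters

TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
LABEL = data.Field(sequential=False, use_vocab=False)

with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
    f.write("\n".join(rows))

# The dict is passed through to csv.reader; quotechar=None disables quoting.
dataset = data.TabularDataset(
    path=f.name, format="csv", skip_header=False,
    fields=[("text", TEXT), ("label", LABEL)],
    csv_reader_params={"quotechar": None})
TEXT.build_vocab(dataset)
print(dataset[0].text)   # ['"', 'hello', 'world']
os.remove(f.name)
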
github pytorch / text / test / data / test_dataset.py
def test_dataset_split_arguments(self):
        num_examples, num_labels = 30, 3
        self.write_test_splitting_dataset(num_examples=num_examples,
                                          num_labels=num_labels)
        text_field = data.Field()
        label_field = data.LabelField()
        fields = [('text', text_field), ('label', label_field)]

        dataset = data.TabularDataset(
            path=self.test_dataset_splitting_path, format="csv", fields=fields)

        # Test default split ratio (0.7)
        expected_train_size = 21
        expected_test_size = 9

        train, test = dataset.split()
        assert len(train) == expected_train_size
        assert len(test) == expected_test_size

        # Test array arguments with same ratio
        split_ratio = [0.7, 0.3]
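
split() also accepts a three-element ratio list for a train/validation/test split, plus a stratified mode that preserves label proportions (using the 'label' attribute by default). A small self-contained sketch, assuming the legacy Dataset.split signature and invented examples:

from torchtext.data import Field, LabelField, Example, Dataset

TEXT = Field(tokenize=lambda s: s.split())
LABEL = LabelField()
fields = [("text", TEXT), ("label", LABEL)]

# 30 toy examples spread over 3 labels, mirroring the test setup above.
examples = [Example.fromlist(["sample sentence number %d" % i, str(i % 3)], fields)
            for i in range(30)]
dataset = Dataset(examples, fields)

# A three-element ratio yields (train, valid, test); a single float yields (train, test).
train, valid, test = dataset.split(split_ratio=[0.6, 0.2, 0.2])
print(len(train), len(valid), len(test))   # 18 6 6

# Stratified splitting keeps per-label proportions intact.
train, test = dataset.split(split_ratio=0.7, stratified=True, strata_field="label")
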
github pytorch / text / test / data / test_dataset.py
def test_errors(self):
        # Ensure that trying to retrieve a key not in JSON data errors
        self.write_test_ppid_dataset(data_format="json")

        question_field = data.Field(sequential=True)
        label_field = data.Field(sequential=False)
        fields = {"qeustion1": ("q1", question_field),
                  "question2": ("q2", question_field),
                  "label": ("label", label_field)}

        with self.assertRaises(ValueError):
            data.TabularDataset(
                path=self.test_ppid_dataset_path, format="json", fields=fields)
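
For format="json" the fields argument must be a dict keyed by the top-level keys of each JSON line, which is why the misspelled "qeustion1" above makes the loader raise a ValueError when the key cannot be found. A hedged sketch of the expected layout, with made-up file contents:

import json
import tempfile
from torchtext import data

# One JSON object per line; the keys must match the keys used in `fields`.
lines = [
    {"question1": "what is torchtext", "question2": "what is torch text", "label": 1},
    {"question1": "how do fields work", "question2": "what is a field", "label": 0},
]

QUESTION = data.Field(sequential=True, tokenize=lambda s: s.split())
LABEL = data.Field(sequential=False, use_vocab=False)

fields = {
    "question1": ("q1", QUESTION),   # JSON key -> (attribute name, Field)
    "question2": ("q2", QUESTION),
    "label": ("label", LABEL),
}

with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
    f.write("\n".join(json.dumps(obj) for obj in lines))

dataset = data.TabularDataset(path=f.name, format="json", fields=fields)
QUESTION.build_vocab(dataset)
print(dataset[0].q1, dataset[0].label)
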
github IBM / pytorch-seq2seq / tests / test_fields.py
def test_targetfield_specials(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(test_path, 'data/eng-fra.txt')
        field = TargetField()
        train = torchtext.data.TabularDataset(
            path=data_path, format='tsv',
            fields=[('src', torchtext.data.Field()), ('trg', field)]
        )
        self.assertTrue(field.sos_id is None)
        self.assertTrue(field.eos_id is None)
        field.build_vocab(train)
        self.assertFalse(field.sos_id is None)
        self.assertFalse(field.eos_id is None)
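
TargetField is a Field subclass that appends start/end-of-sequence tokens and caches their vocabulary ids once build_vocab has run, which is why sos_id and eos_id are None before that call. The same lookup can be done with a plain Field, roughly as follows (toy sentences; the <sos>/<eos> names are only a convention, not necessarily what pytorch-seq2seq uses internally):

from torchtext.data import Field, Example, Dataset

# init_token/eos_token are prepended/appended automatically when batches are processed.
TRG = Field(init_token="<sos>", eos_token="<eos>", tokenize=lambda s: s.split())
fields = [("trg", TRG)]

examples = [Example.fromlist([s], fields) for s in ["bonjour le monde", "au revoir"]]
TRG.build_vocab(Dataset(examples, fields))

# Before build_vocab there is no stoi mapping, so these ids simply do not exist yet.
sos_id = TRG.vocab.stoi[TRG.init_token]
eos_id = TRG.vocab.stoi[TRG.eos_token]
print(sos_id, eos_id)
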
github kolloldas / torchnlp / tests / tasks / test_sequence_tagging.py
def udpos_dataset(batch_size):
    # Setup fields with batch dimension first
    inputs = data.Field(init_token="<bos>", eos_token="<eos>", batch_first=True)
    tags = data.Field(init_token="<bos>", eos_token="<eos>", batch_first=True)
    
    # Download and load the default data.
    train, val, test = datasets.UDPOS.splits(
    fields=(('inputs_word', inputs), ('labels', tags), (None, None)))
    
    # Build vocab
    inputs.build_vocab(train.inputs)
    tags.build_vocab(train.tags)
    
    # Get iterators
    train_iter, val_iter, test_iter = data.BucketIterator.splits(
                            (train, val, test), batch_size=batch_size, 
                            device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    train_iter.repeat = False
    return train_iter, val_iter, test_iter, inputs, tags
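
Each batch produced by these iterators exposes one attribute per field name from the splits() call, already padded and, because batch_first=True, shaped (batch, sequence length). A rough sketch of the consumption side, assuming the udpos_dataset helper above (the UDPOS corpus is downloaded on first use):

# Hypothetical usage of the helper defined above.
train_iter, val_iter, test_iter, inputs, tags = udpos_dataset(batch_size=32)

for batch in train_iter:
    words = batch.inputs_word   # LongTensor (batch, seq_len), padded with the inputs field's pad token
    labels = batch.labels       # LongTensor of the same shape, holding tag indices
    print(words.shape, labels.shape)
    break                       # one batch is enough to illustrate the shapes
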
github tunz / transformer-pytorch / dataset / translation.py
if load_preprocessed:
        print("Loading preprocessed data...")
        src_field = torch.load(data_dir + '/source.pt')['field']
        trg_field = torch.load(data_dir + '/target.pt')['field']

        data_paths = glob.glob(data_dir + '/examples-train-*.pt')
        examples_train = torch.load(data_paths[0])
        examples_val = torch.load(data_dir + '/examples-val-0.pt')

        fields = [('src', src_field), ('trg', trg_field)]
        train = WMT32k(examples_train, fields, filter_pred=filter_pred)
        val = WMT32k(examples_val, fields, filter_pred=filter_pred)
    else:
        src_field = data.Field(tokenize=tokenize_de, batch_first=True,
                               pad_token=pad, lower=True, eos_token='<eos>')
        trg_field = data.Field(tokenize=tokenize_en, batch_first=True,
                               pad_token=pad, lower=True, eos_token='<eos>')

        print("Loading data... (this may take a while)")
        train, val, data_paths = \
            WMT32k.splits(exts=('.de', '.en'),
                          fields=(src_field, trg_field),
                          data_dir=data_dir,
                          filter_pred=filter_pred)

        print("Building vocabs... (this may take a while)")
        build_vocabs(src_field, trg_field, data_paths)

    print("Creating iterators...")
    train_iter, val_iter = common.BucketByLengthIterator.splits(
        (train, val),
        data_paths=data_paths,
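
The load_preprocessed branch works because a built Field, including its Vocab, can be stored with torch.save and restored later without touching the raw corpus again; the tokenizer is easiest to persist when it is a named module-level function rather than a lambda. A minimal sketch of that round trip (file name and sentences are invented; it assumes a torchtext version whose Field and Vocab pickle cleanly, which this project relies on):

import torch
from torchtext.data import Field, Example, Dataset

def whitespace_tokenize(text):
    # A named, module-level function keeps the Field picklable (a lambda would not be).
    return text.split()

SRC = Field(tokenize=whitespace_tokenize, batch_first=True, lower=True)
fields = [("src", SRC)]
examples = [Example.fromlist([s], fields) for s in ["ein kleiner test", "noch ein satz"]]
SRC.build_vocab(Dataset(examples, fields))

# Persist the field together with its vocabulary...
torch.save({"field": SRC}, "source.pt")

# ...and restore it later; the vocab (itos/stoi) comes back intact.
restored = torch.load("source.pt")["field"]
print(restored.vocab.itos)
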
github Impavidity / pbase / pbase_old / charField.py
from torchtext.data import Field, Dataset
from torchtext.vocab import Vocab
from collections import Counter, OrderedDict
from torch.autograd import Variable

class CharField(Field):

    vocab_cls = Vocab

    def __init__(self, **kwargs):
        super(CharField, self).__init__(**kwargs)
        if self.preprocessing is None:
            self.preprocessing = lambda x: [list(y) for y in x]

    def build_vocab(self, *args, **kwargs):
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, Dataset):
                sources += [getattr(arg, name) for name, field in
                            arg.fields.items() if field is self]
            else:
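
With that default preprocessing, every token is further split into a list of characters, so each example carries a nested characters-per-token structure; the overridden build_vocab then has to count the flattened characters rather than whole tokens. A rough usage sketch, assuming the CharField class above is importable and its (truncated) build_vocab override behaves like Field's:

from torchtext.data import Field, Example, Dataset

WORD = Field(tokenize=lambda s: s.split())
CHAR = CharField(tokenize=lambda s: s.split())   # preprocessing splits each token into characters

fields = [("word", WORD), ("char", CHAR)]
examples = [Example.fromlist([s, s], fields) for s in ["hello world", "char level fields"]]
dataset = Dataset(examples, fields)

print(dataset[0].char)   # [['h', 'e', 'l', 'l', 'o'], ['w', 'o', 'r', 'l', 'd']]

WORD.build_vocab(dataset)
CHAR.build_vocab(dataset)   # counts individual characters via the overridden build_vocab
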
github castorini / hedwig / datasets / aapd.py
def process_labels(string):
    """
    Converts the label string into a list of floats (one value per class)
    :param string: binary label string, e.g. "0101"
    :return: list of floats
    """
    return [float(x) for x in string]


class AAPD(TabularDataset):
    NAME = 'AAPD'
    NUM_CLASSES = 54
    IS_MULTILABEL = True

    TEXT_FIELD = Field(batch_first=True, tokenize=clean_string, include_lengths=True)
    LABEL_FIELD = Field(sequential=False, use_vocab=False, batch_first=True, preprocessing=process_labels)

    @staticmethod
    def sort_key(ex):
        return len(ex.text)

    @classmethod
    def splits(cls, path, train=os.path.join('AAPD', 'train.tsv'),
               validation=os.path.join('AAPD', 'dev.tsv'),
               test=os.path.join('AAPD', 'test.tsv'), **kwargs):
        return super(AAPD, cls).splits(
            path, train=train, validation=validation, test=test,
            format='tsv', fields=[('label', cls.LABEL_FIELD), ('text', cls.TEXT_FIELD)]
        )

    @classmethod
    def iters(cls, path, vectors_name, vectors_cache, batch_size=64, shuffle=True, device=0, vectors=None,
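
Here the Field's preprocessing hook turns each label string such as "0101..." into a list of floats, and use_vocab=False means those values skip vocabulary lookup and pass straight through numericalization, giving one multi-hot row per document. A stripped-down sketch of the same pattern on invented data (dtype=torch.float is added here so the batch is directly usable with a multi-label loss such as BCEWithLogitsLoss; the AAPD field above does not set it):

import torch
from torchtext.data import Field, Example, Dataset, Iterator

def labels_to_floats(string):
    # "0101" -> [0.0, 1.0, 0.0, 1.0]; mirrors process_labels above.
    return [float(x) for x in string]

TEXT = Field(batch_first=True, tokenize=lambda s: s.split())
LABEL = Field(sequential=False, use_vocab=False, batch_first=True,
              preprocessing=labels_to_floats, dtype=torch.float)

fields = [("label", LABEL), ("text", TEXT)]
examples = [Example.fromlist(["0101", "multi label document one"], fields),
            Example.fromlist(["1100", "another toy document"], fields)]
dataset = Dataset(examples, fields)
TEXT.build_vocab(dataset)

batch = next(iter(Iterator(dataset, batch_size=2, repeat=False, sort=False, device="cpu")))
print(batch.label)   # float tensor of shape (2, 4) with 0./1. entries
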
github zomux / nmtlab / nmtlab / dataset / mt_dataset.py
def __init__(self, corpus_path=None, src_corpus=None, tgt_corpus=None, src_vocab=None, tgt_vocab=None, batch_size=64, batch_type="sentence", max_length=60, n_valid_samples=1000, truncate=None):
        
        assert corpus_path is not None or (src_corpus is not None and tgt_corpus is not None)
        assert src_vocab is not None and tgt_vocab is not None
    
        self._batch_size = batch_size
        self._fixed_train_batches = None
        self._fixed_valid_batches = None
        self._max_length = max_length
        self._n_valid_samples = n_valid_samples
        
        self._src_field = torchtext.data.Field(pad_token="<null>", preprocessing=lambda seq: ["<s>"] + seq + ["</s>"])
        self._src_vocab = self._src_field.vocab = Vocab(src_vocab)
        self._tgt_field = torchtext.data.Field(pad_token="<null>", preprocessing=lambda seq: ["<s>"] + seq + ["</s>"])
        self._tgt_vocab = self._tgt_field.vocab = Vocab(tgt_vocab)
        # Make data
        if corpus_path is not None:
            self._data = torchtext.data.TabularDataset(
                path=corpus_path, format='tsv',
                fields=[('src', self._src_field), ('tgt', self._tgt_field)],
                filter_pred=self._len_filter
            )
        else:
            self._data = BilingualDataset(src_corpus, tgt_corpus, self._src_field, self._tgt_field, filter_pred=self._len_filter)
        # Create training and valid dataset
        examples = self._data.examples
        if truncate is not None:
            assert type(truncate) == int
            examples = examples[:truncate]
        n_train_samples = len(examples) - n_valid_samples
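
The notable trick in this constructor is that build_vocab is never called: a pre-built vocabulary object is assigned directly to field.vocab (nmtlab uses its own Vocab wrapper here), and numericalization then uses that external word list. The same thing works with torchtext's Vocab built from a Counter, roughly like this (the words and counts are made up, and calling Field.process with just the batch assumes a reasonably recent legacy torchtext):

from collections import Counter
from torchtext.data import Field
from torchtext.vocab import Vocab

SRC = Field(preprocessing=lambda seq: ["<s>"] + seq + ["</s>"])

# Build (or load) a vocabulary independently of any Dataset...
counter = Counter({"hello": 10, "world": 7})
SRC.vocab = Vocab(counter, specials=["<unk>", "<pad>", "<s>", "</s>"])

# ...and the field numericalizes with it directly, no build_vocab needed.
tokens = SRC.preprocess("hello world")   # tokenize, then add sentence boundary markers
tensor = SRC.process([tokens])           # pad + numericalize; shape (seq_len, 1) since batch_first=False
print(tensor.squeeze(1).tolist())
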
github michaelchen110 / Grammar-Correction / transformer / transformer_train.py
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = ("cpu")
    # devices = [0, 1, 2, 3]

    #####################
    #   Data Loading    #
    #####################
    BOS_WORD = '<s>'
    EOS_WORD = '</s>'
    BLANK_WORD = "<blank>"
    MIN_FREQ = 2

    spacy_en = spacy.load('en')
    def tokenize_en(text):
        return [tok.text for tok in spacy_en.tokenizer(text)]
    TEXT = data.Field(tokenize=tokenize_en, init_token = BOS_WORD,
                     eos_token = EOS_WORD, pad_token=BLANK_WORD)

    train = datasets.TranslationDataset(path=os.path.join(SRC_DIR, DATA),
            exts=('.train.src', '.train.trg'), fields=(TEXT, TEXT))
    val = datasets.TranslationDataset(path=os.path.join(SRC_DIR, DATA), 
            exts=('.val.src', '.val.trg'), fields=(TEXT, TEXT))

    train_iter = MyIterator(train, batch_size=BATCH_SIZE, device=device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=True)
    valid_iter = MyIterator(val, batch_size=BATCH_SIZE, device=device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=False)

    random_idx = random.randint(0, len(train) - 1)
    print(train[random_idx].src)
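
Because the same TEXT field is attached to both the source and target sides, one shared vocabulary serves the whole model; the MIN_FREQ constant defined above is the usual knob for dropping rare tokens when that vocabulary is built. The snippet is cut off before that step, but scripts like this presumably continue along these lines:

    # Build one shared vocabulary over the training pairs, dropping tokens seen
    # fewer than MIN_FREQ times; the pad index is handy for masking in the loss.
    TEXT.build_vocab(train, min_freq=MIN_FREQ)
    pad_idx = TEXT.vocab.stoi[BLANK_WORD]
    print(len(TEXT.vocab), pad_idx)
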