How to use the babi.Sentence class in babi

To help you get started, we’ve selected a few babi examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github chainer / chainer / examples / memnn / test_memnn.py View on Github external
network = model.predictor
    max_memory = network.max_memory
    id_to_vocab = {i: v for v, i in vocab.items()}

    test_data = babi.read_data(vocab, args.DATA)
    print('Test data: %s: %d' % (args.DATA, len(test_data)))

    sentence_len = max(max(len(s.sentence) for s in story)
                       for story in test_data)
    correct = total = 0
    for story in test_data:
        mem = xp.zeros((max_memory, sentence_len), dtype=numpy.int32)
        i = 0
        for sent in story:
            if isinstance(sent, babi.Sentence):
                if i == max_memory:
                    mem[0:i - 1, :] = mem[1:i, :]
                    i -= 1
                mem[i, 0:len(sent.sentence)] = xp.asarray(sent.sentence)
                i += 1
            elif isinstance(sent, babi.Query):
                query = xp.array(sent.sentence, dtype=numpy.int32)

                # networks assumes mini-batch data
                score = network(mem[None], query[None])[0]
                answer = int(xp.argmax(score.array))

                if answer == sent.answer:
                    correct += 1
                total += 1
                print(id_to_vocab[answer], id_to_vocab[sent.answer])
github unnonouno / chainer-memnn / babi.py View on Github external
def parse_line(vocab, line):
    """Parse one line of input into a ``Sentence`` or a ``Query``.

    A line that contains a tab is treated as a question line made of three
    tab-separated fields: the question text, the answer, and a
    space-separated list of integer fact ids. Any other line is treated as
    a plain sentence line.

    Args:
        vocab: Vocabulary object handed through to ``convert``
            (presumably a word-to-id mapping; confirm against ``convert``).
        line: The raw text line to parse.

    Returns:
        ``Query(word_ids, answer_id, fact_ids)`` for a question line,
        otherwise ``Sentence(word_ids)``.

    Raises:
        ValueError: If a question line does not split into exactly three
            tab-separated fields, or a fact id is not an integer.
    """
    if '\t' not in line:
        # sentence line
        return Sentence(convert(vocab, split(line)))

    # question line
    question, answer, fact_id = line.split('\t')
    # The answer is converted before the question words, same order as
    # always, in case ``convert`` has side effects on ``vocab``.
    answer_id = convert(vocab, [answer])[0]
    question_ids = convert(vocab, split(question))
    fact_ids = [int(token) for token in fact_id.split(' ')]
    return Query(question_ids, answer_id, fact_ids)
github unnonouno / chainer-memnn / memnn.py View on Github external
def convert_data(train_data, max_memory):
    """Turn parsed stories into fixed-size training examples.

    Each story is walked in order. Every ``babi.Sentence`` is written as one
    zero-padded row of a ``(max_memory, sentence_len)`` int32 memory matrix;
    once all rows are in use, the oldest row is dropped to make room for the
    new one. Every ``babi.Query`` emits one example holding a snapshot of the
    memory at that point. Items of any other type are ignored.

    Args:
        train_data: Sequence of stories (it is iterated twice). Each story
            is an iterable of objects with a ``sentence`` sequence;
            ``babi.Query`` objects also provide ``answer``.
        max_memory: Number of most recent sentences kept in the memory.

    Returns:
        A list of dicts with the keys ``'sentences'`` (copy of the memory
        matrix), ``'question'`` (zero-padded int32 vector of length
        ``sentence_len``) and ``'answer'`` (numpy array built from
        ``sent.answer``). Empty when ``train_data`` is empty.
    """
    if not train_data:
        # max() below would raise ValueError on an empty sequence.
        return []

    all_data = []
    # Longest sentence or question over all stories; sets the row width.
    sentence_len = max(max(len(s.sentence) for s in story)
                       for story in train_data)
    for story in train_data:
        mem = numpy.zeros((max_memory, sentence_len), dtype=numpy.int32)
        i = 0  # index of the next free memory row
        for sent in story:
            if isinstance(sent, babi.Sentence):
                if i == max_memory:
                    # Memory is full: drop the oldest sentence by shifting
                    # every row up by one.
                    mem[0:i - 1, :] = mem[1:i, :]
                    i -= 1
                    # The freed row still holds the previous sentence.
                    # Clear it so a shorter new sentence does not keep
                    # stale trailing tokens where padding should be.
                    mem[i, :] = 0
                mem[i, 0:len(sent.sentence)] = sent.sentence
                i += 1
            elif isinstance(sent, babi.Query):
                query = numpy.zeros(sentence_len, dtype=numpy.int32)
                query[0:len(sent.sentence)] = sent.sentence
                all_data.append({
                    # copy: ``mem`` keeps being mutated by later sentences
                    'sentences': mem.copy(),
                    'question': query,
                    'answer': numpy.array(sent.answer, 'i'),
                })

    return all_data
github chainer / chainer / examples / memnn / memnn.py View on Github external
def convert_data(train_data, max_memory):
    """Convert stories into a list of memory/question/answer examples.

    For every story a ``(max_memory, sentence_len)`` int32 matrix is kept as
    a sliding window over the sentences seen so far: each ``babi.Sentence``
    fills the next row (zero-padded on the right), and when the window is
    full the oldest row is discarded first. Each ``babi.Query`` produces one
    example containing a copy of the current window. Items that are neither
    ``babi.Sentence`` nor ``babi.Query`` are skipped.

    Args:
        train_data: Sequence of stories (traversed twice). A story is an
            iterable of objects exposing a ``sentence`` sequence;
            ``babi.Query`` objects additionally expose ``answer``.
        max_memory: Maximum number of sentences held in the window.

    Returns:
        A list of dicts with the keys ``'sentences'``, ``'question'`` and
        ``'answer'``, all int32 numpy arrays. Empty when ``train_data`` is
        empty.
    """
    if not train_data:
        # Avoid ValueError from max() over an empty sequence.
        return []

    all_data = []
    # Width of every row: the longest sentence or question in the data.
    sentence_len = max(max(len(s.sentence) for s in story)
                       for story in train_data)
    for story in train_data:
        mem = numpy.zeros((max_memory, sentence_len), dtype=numpy.int32)
        i = 0  # next row to fill
        for sent in story:
            if isinstance(sent, babi.Sentence):
                if i == max_memory:
                    # Window full: shift rows up, discarding the oldest.
                    mem[0:i - 1, :] = mem[1:i, :]
                    i -= 1
                    # Reset the freed row; otherwise tokens of the previous
                    # sentence leak past the end of a shorter new sentence.
                    mem[i, :] = 0
                mem[i, 0:len(sent.sentence)] = sent.sentence
                i += 1
            elif isinstance(sent, babi.Query):
                query = numpy.zeros(sentence_len, dtype=numpy.int32)
                query[0:len(sent.sentence)] = sent.sentence
                all_data.append({
                    # snapshot, because ``mem`` is updated in place later
                    'sentences': mem.copy(),
                    'question': query,
                    'answer': numpy.array(sent.answer, numpy.int32),
                })

    return all_data