How to use ParlAI: 10 common examples

To help you get started, we've selected a few ParlAI examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: facebookresearch/ParlAI — tests/tasks/test_wizard_of_wikipedia.py (view on GitHub)
def run_display_test(self, kwargs):
    """
    Build the task configured by ``kwargs``, display it, and check the log.

    All stdout produced while parsing options, creating the world and
    calling ``display`` is captured. The captured text must contain the
    "[ loaded N episodes with a total of M examples ]" line, with N and M
    taken from the created world.

    :param kwargs: option defaults applied to the parser before parsing.
    """
    captured = io.StringIO()
    with redirect_stdout(captured):
        parser = setup_args()
        parser.set_defaults(**kwargs)
        opt = parser.parse_args()
        world = create_task(opt, RepeatLabelAgent(opt))
        display(opt)

    # Build the line we expect the loader to have printed.
    expected_line = '[ loaded {} episodes with a total of {} examples ]'.format(
        world.num_episodes(), world.num_examples()
    )
    failure_msg = 'Wizard of Wikipedia failed with following args: {}'.format(opt)
    self.assertTrue(expected_line in captured.getvalue(), failure_msg)
GitHub: facebookresearch/ParlAI — parlai/tasks/mctest/agents.py (view on GitHub)
from parlai.core.teachers import FbDialogTeacher
from .build import build

import copy
import os


def _path(opt, filtered):
    """
    Make sure MCTest is built, then return the path of one data file.

    :param opt: options dict; ``opt['datapath']`` is the data root, and the
        part of ``opt['datatype']`` before any ':' selects the split.
    :param filtered: suffix appended to the split name (e.g. '160', '500').
    :return: ``<datapath>/MCTest/<split><filtered>.txt``
    """
    # Build the data if it doesn't exist.
    build(opt)
    split_name = opt['datatype'].partition(':')[0]
    file_name = split_name + filtered + '.txt'
    return os.path.join(opt['datapath'], 'MCTest', file_name)


class Task160Teacher(FbDialogTeacher):
    """MCTest teacher reading the ``<split>160.txt`` file."""

    def __init__(self, opt, shared=None):
        # Work on a private copy so setting 'datafile' does not touch the
        # caller's options.
        task_opt = copy.deepcopy(opt)
        task_opt['datafile'] = _path(task_opt, '160')
        super().__init__(task_opt, shared)


class Task500Teacher(FbDialogTeacher):
    """MCTest teacher reading the ``<split>500.txt`` file."""

    def __init__(self, opt, shared=None):
        # Work on a private copy so setting 'datafile' does not touch the
        # caller's options.
        task_opt = copy.deepcopy(opt)
        task_opt['datafile'] = _path(task_opt, '500')
        super().__init__(task_opt, shared)


class DefaultTeacher(Task500Teacher):
    """Default MCTest teacher; behaves exactly like ``Task500Teacher``."""
GitHub: facebookresearch/ParlAI — parlai/tasks/booktest/build.py (view on GitHub)
def build(opt):
    """
    Download the BookTest data into ``<datapath>/BookTest`` unless already built.

    If an older build is present, its directory is removed first. Every entry
    in ``RESOURCES`` is then downloaded into the directory, and the directory
    is marked as built.

    :param opt: options dict; only ``opt['datapath']`` is read.
    """
    dpath = os.path.join(opt['datapath'], 'BookTest')
    version = None

    # Nothing to do when this version of the data is already in place.
    if build_data.built(dpath, version_string=version):
        return

    print('[building data: ' + dpath + ']')
    if build_data.built(dpath):
        # An older version exists, so remove these outdated files.
        build_data.remove_dir(dpath)
    build_data.make_dir(dpath)

    # Fetch each downloadable resource into the task directory.
    for resource in RESOURCES:
        resource.download_file(dpath)

    # Record that the data has been built at this version.
    build_data.mark_done(dpath, version_string=version)
GitHub: facebookresearch/ParlAI — tests/integration/test_downloads.py (view on GitHub)
def test_qangaroo(self):
    """
    Build the QAngaroo default teacher and sanity-check its first reply.

    Options are parsed from ``self.args`` with the datatype forced to 'train'.
    The first ``act()`` result is passed to ``check``. ``self.TMP_PATH`` is
    removed afterwards. This happens inside ``finally`` so the directory is
    cleaned up even when teacher construction or ``check`` fails.
    """
    from parlai.core.params import ParlaiParser
    from parlai.tasks.qangaroo.agents import DefaultTeacher

    try:
        opt = ParlaiParser().parse_args(args=self.args)
        opt['datatype'] = 'train'
        teacher = DefaultTeacher(opt)
        reply = teacher.act()
        check(opt, reply)
    finally:
        # Best-effort cleanup. The directory may not exist if the failure
        # happened early. Ignoring errors here keeps a cleanup failure from
        # masking the original test failure.
        shutil.rmtree(self.TMP_PATH, ignore_errors=True)
GitHub: facebookresearch/ParlAI — tests/test_tga.py (view on GitHub)
def test_file_inference(self):
    """
    Test --inference with older model files.

    Loading the unittest ``transformer_generator2`` model with no extra flags
    must select 'greedy' inference. Adding ``--beam-size 5`` must select
    'beam' inference.
    """
    testing_utils.download_unittest_models()
    model_file = 'zoo:unittest/transformer_generator2/model'

    # No decoding flags: expect greedy inference.
    with testing_utils.capture_output():
        greedy_opt = ParlaiParser(True, True).parse_args(['--model-file', model_file])
        greedy_agent = create_agent(greedy_opt, True)
        self.assertEqual(greedy_agent.opt['inference'], 'greedy')

    # An explicit beam size: expect beam inference.
    with testing_utils.capture_output():
        beam_opt = ParlaiParser(True, True).parse_args(
            ['--model-file', model_file, '--beam-size', '5'],
            print_args=False,
        )
        beam_agent = create_agent(beam_opt, True)
        self.assertEqual(beam_agent.opt['inference'], 'beam')
GitHub: facebookresearch/ParlAI — tests/test_pytorch_data_teacher.py (view on GitHub)
def get_acts_epochs_1_and_2(defaults):
            parser.set_defaults(**defaults)
            opt = parser.parse_args()
            build_dict(opt)
            agent = create_agent(opt)
            world_data = create_task(opt, agent)
            acts_epoch_1 = []
            acts_epoch_2 = []
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_1.append(world_data.acts[0])
            world_data.reset()
            while not world_data.epoch_done():
                world_data.parley()
                acts_epoch_2.append(world_data.acts[0])
            acts_epoch_1 = [bb for b in acts_epoch_1 for bb in b]
            acts_epoch_1 = sorted(
                [b for b in acts_epoch_1 if 'text' in b], key=lambda x: x.get('text')
            )
            acts_epoch_2 = [bb for b in acts_epoch_2 for bb in b]
            acts_epoch_2 = sorted(
                [b for b in acts_epoch_2 if 'text' in b], key=lambda x: x.get('text')
GitHub: facebookresearch/ParlAI — tests/test_params.py (view on GitHub)
with testing_utils.capture_output() as _:
                modfn = os.path.join(tmp, 'model')
                with open(modfn, 'w') as f:
                    f.write('Test.')
                optfn = modfn + '.opt'
                base_opt = {
                    'model': 'tests.test_params:_ExampleUpgradeOptAgent',
                    'dict_file': modfn + '.dict',
                    'model_file': modfn,
                }
                with open(optfn, 'w') as f:
                    json.dump(base_opt, f)

                pp = ParlaiParser(True, True)
                opt = pp.parse_args(['--model-file', modfn])
                agents.create_agent(opt)
GitHub: facebookresearch/ParlAI — tests/test_dict.py (view on GitHub)
def test_gpt2_bpe_tokenize(self):
        with testing_utils.capture_output():
            opt = Opt({'dict_tokenizer': 'gpt2', 'datapath': './data'})
            agent = DictionaryAgent(opt)
        self.assertEqual(
            # grinning face emoji
            agent.gpt2_tokenize(u'Hello, ParlAI! \U0001f600'),
            [
                'Hello',
                ',',
                r'\xc4\xa0Par',
                'l',
                'AI',
                '!',
                r'\xc4\xa0\xc3\xb0\xc5\x81\xc4\xba',
                r'\xc4\xa2',
            ],
        )
        self.assertEqual(
            agent.vec2txt(
GitHub: facebookresearch/ParlAI — parlai/tasks/mctest/build.py (view on GitHub)
def build(opt):
    dpath = os.path.join(opt['datapath'], 'MCTest')
    version = None

    if not build_data.built(dpath, version_string=version):
        print('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        fname = 'mctest.tar.gz'
        url = 'http://parl.ai/downloads/mctest/' + fname
        build_data.download(url, dpath, fname)
        build_data.untar(dpath, fname)

        dpext = os.path.join(dpath, 'mctest')
        create_fb_format(
            dpath, 'train160', os.path.join(dpext, 'MCTest', 'mc160.train'), None
        )
        create_fb_format(
            dpath, 'valid160', os.path.join(dpext, 'MCTest', 'mc160.dev'), None
        )
        create_fb_format(
            dpath,
            'test160',
            os.path.join(dpext, 'MCTest', 'mc160.test'),
            os.path.join(dpext, 'MCTestAnswers', 'mc160.test.ans'),
        )
        create_fb_format(
            dpath, 'train500', os.path.join(dpext, 'MCTest', 'mc500.train'), None
GitHub: facebookresearch/ParlAI — tests/test_loader.py (view on GitHub)
def test_load_agent(self):
    """Loading the agent named by ``OPTIONS['agent']`` must give ``RepeatLabelAgent``."""
    self.assertEqual(load_agent_module(OPTIONS['agent']), RepeatLabelAgent)