How to use the parlai.utils.testing module in parlai

To help you get started, we’ve selected a few parlai.utils.testing examples, based on popular ways it is used in public projects.
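
The helpers below are plain functions and context managers in parlai.utils.testing. A minimal setup sketch, assuming ParlAI is installed (the import alias matches the examples on this page; tempdir yields a throwaway directory and removes it on exit):

import os

import parlai.utils.testing as testing_utils

# tempdir() yields a temporary directory path and cleans it up afterwards.
with testing_utils.tempdir() as tmpdir:
    data_path = os.path.join(tmpdir, 'data')
    os.makedirs(data_path)
    # ... point a task's datapath here, as in the examples below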

github facebookresearch / ParlAI / tests / test_fairseq.py
    @testing_utils.skipUnlessGPU
    @unittest.skipIf(SKIP_TESTS, "Fairseq not installed")
    def test_labelcands(self):
        stdout, valid, test = testing_utils.train_model(
            dict(
                task='integration_tests:candidate',
                model='fairseq',
                arch='lstm_wiseman_iwslt_de_en',
                lr=LR,
                batchsize=BATCH_SIZE,
                num_epochs=NUM_EPOCHS,
                rank_candidates=True,
                skip_generation=True,
            )
        )

        self.assertTrue(
            valid['hits@1'] > 0.95,
            "valid hits@1 = {}\nLOG:\n{}".format(valid['hits@1'], stdout),
        )
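
testing_utils.train_model takes a flat opt dict and, in the version shown here, returns the captured stdout along with validation and test metric dicts. A minimal sketch with placeholder hyperparameters (LR, BATCH_SIZE, and NUM_EPOCHS above are constants defined elsewhere in the test file):

import parlai.utils.testing as testing_utils

stdout, valid, test = testing_utils.train_model(
    dict(
        task='integration_tests:candidate',  # tiny built-in sanity task
        model='seq2seq',                     # placeholder model choice
        batchsize=16,
        num_epochs=1,
    )
)
# valid and test are metric dicts, e.g. valid['hits@1']
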
github facebookresearch / ParlAI / tests / test_teachers.py
    def test_display_data(self):
        """Test that, with pre-loaded image features, all examples are different."""

        def _test_display_output(opt):
            output = testing_utils.display_data(opt)
            train_labels = re.findall(r"\[labels: .*\]", output[0])
            valid_labels = re.findall(r"\[eval_labels: .*\]", output[1])
            test_labels = re.findall(r"\[eval_labels: .*\]", output[2])

            for i, lbls in enumerate([train_labels, valid_labels, test_labels]):
                self.assertGreater(len(lbls), 0, 'DisplayData failed')
                self.assertEqual(len(lbls), len(set(lbls)), output[i])

        with testing_utils.tempdir() as tmpdir:
            data_path = tmpdir
            os.makedirs(os.path.join(data_path, 'ImageTeacher'))

            opt = {'task': 'integration_tests:ImageTeacher', 'datapath': data_path}
            for image_mode in ['resnet152', 'no_image_model']:
                opt['image_mode'] = image_mode
                _test_display_output(opt)
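
testing_utils.display_data runs the display_data script over the train, valid, and test splits and returns one output string per split, which is why the test above indexes output[0] through output[2]. A minimal sketch:

import parlai.utils.testing as testing_utils

output = testing_utils.display_data({'task': 'babi:task1k:1'})
train_out, valid_out, test_out = output  # one printed transcript per split
print(train_out[:200])
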
github facebookresearch / ParlAI / tests / test_build_data.py
    def test_download_multiprocess(self):
        urls = [
            'https://parl.ai/downloads/mnist/mnist.tar.gz',
            'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
            'https://parl.ai/downloads/mnist/mnist.tar.gz.BAD',
        ]

        with testing_utils.capture_output() as stdout:
            download_results = build_data.download_multiprocess(
                urls, self.datapath, dest_filenames=self.dest_filenames
            )
        stdout = stdout.getvalue()

        output_filenames, output_statuses, output_errors = zip(*download_results)
        self.assertEqual(
            output_filenames,
            self.dest_filenames,
            f'output filenames not correct\n{stdout}',
        )
        self.assertEqual(
            output_statuses,
            (200, 403, 403),
            f'output http statuses not correct\n{stdout}',
        )
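
testing_utils.capture_output is a context manager that redirects printed output into a StringIO-style buffer; calling getvalue() afterwards, as above, recovers the text so it can be embedded in assertion messages. A minimal sketch:

import parlai.utils.testing as testing_utils

with testing_utils.capture_output() as captured:
    print('noisy output from a ParlAI call')
log = captured.getvalue()
assert 'noisy output' in log
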
github facebookresearch / ParlAI / tests / nightly / gpu / test_wizard.py
    def test_retrieval(self):
        stdout, _, test = testing_utils.eval_model(RETRIEVAL_OPTIONS)
        self.assertGreaterEqual(
            test['accuracy'],
            0.86,
            'test acc = {}\nLOG:\n{}'.format(test['accuracy'], stdout),
        )
        self.assertGreaterEqual(
            test['hits@5'],
            0.98,
            'test hits@5 = {}\nLOG:\n{}'.format(test['hits@5'], stdout),
        )
        self.assertGreaterEqual(
            test['hits@10'],
            0.99,
            'test hits@10 = {}\nLOG:\n{}'.format(test['hits@10'], stdout),
        )
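
testing_utils.eval_model mirrors train_model: pass an opt dict (RETRIEVAL_OPTIONS above is defined elsewhere in the test file) and get captured stdout plus validation and test metric dicts back. A minimal sketch with illustrative options:

import parlai.utils.testing as testing_utils

stdout, valid, test = testing_utils.eval_model(
    dict(
        task='babi:task1k:1',
        model_file='zoo:unittest/seq2seq/model',  # small test-suite model
        batchsize=16,
    )
)
print('test accuracy:', test['accuracy'])
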
github facebookresearch / ParlAI / tests / test_dict.py
"""
        # Download model, move to a new location
        datapath = ParlaiParser().parse_args(print_args=False)['datapath']
        try:
            # remove unittest models if there before
            shutil.rmtree(os.path.join(datapath, 'models/unittest'))
        except FileNotFoundError:
            pass
        testing_utils.download_unittest_models()

        zoo_path = 'zoo:unittest/seq2seq/model'
        model_path = modelzoo_path(datapath, zoo_path)
        os.remove(model_path + '.dict')
        # Test that eval model fails
        with self.assertRaises(RuntimeError):
            testing_utils.eval_model(dict(task='babi:task1k:1', model_file=model_path))
        try:
            # remove unittest models if there after
            shutil.rmtree(os.path.join(datapath, 'models/unittest'))
        except FileNotFoundError:
            pass
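
The zoo: prefix is resolved by modelzoo_path (imported in ParlAI's tests from parlai.core.build_data), which maps a model-zoo identifier to a concrete path under datapath, while testing_utils.download_unittest_models fetches the small models the test suite relies on. A sketch of that resolution step:

from parlai.core.build_data import modelzoo_path
from parlai.core.params import ParlaiParser

import parlai.utils.testing as testing_utils

datapath = ParlaiParser().parse_args(print_args=False)['datapath']
testing_utils.download_unittest_models()
model_path = modelzoo_path(datapath, 'zoo:unittest/seq2seq/model')
print(model_path)  # resolves under <datapath>/models/unittest/...
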
github facebookresearch / ParlAI / tests / nightly / gpu / test_wizard.py
    @classmethod
    def setUpClass(cls):
        # go ahead and download things here
        with testing_utils.capture_output():
            parser = display_data.setup_args()
            parser.set_defaults(**END2END_OPTIONS)
            opt = parser.parse_args(print_args=False)
            opt['num_examples'] = 1
            display_data.display_data(opt)
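
setUpClass above shows a second way to build an opt: start from a script's argument parser, override its defaults, then parse. A sketch of the same pattern (the task override here is illustrative, not from the original test):

import parlai.scripts.display_data as display_data

import parlai.utils.testing as testing_utils

with testing_utils.capture_output():  # silence download/progress logs
    parser = display_data.setup_args()
    parser.set_defaults(task='babi:task1k:1')  # illustrative override
    opt = parser.parse_args(print_args=False)
    opt['num_examples'] = 1
    display_data.display_data(opt)
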
github facebookresearch / ParlAI / tests / nightly / gpu / test_self_feeding.py
    def test_released_model(self):
        """
        Check the pretrained model produces correct results.
        """
        _, _, test = testing_utils.eval_model(
            {
                'model_file': 'zoo:self_feeding/hh131k_hb60k_fb60k_st1k/model',
                'task': 'self_feeding:all',
                'batchsize': 20,
            },
            skip_valid=True,
        )

        self.assertAlmostEqual(test['dia_acc'], 0.506, delta=0.001)
        self.assertAlmostEqual(test['fee_acc'], 0.744, delta=0.001)
        self.assertAlmostEqual(test['sat_f1'], 0.8343, delta=0.0001)
github facebookresearch / ParlAI / tests / test_torch_agent.py
def get_agent(**kwargs):
    r"""
    Return opt-initialized agent.

    :param kwargs: any kwargs you want to set using parser.set_params(\*\*kwargs)
    """
    if 'no_cuda' not in kwargs:
        kwargs['no_cuda'] = True
    from parlai.core.params import ParlaiParser

    parser = ParlaiParser()
    MockTorchAgent.add_cmdline_args(parser)
    parser.set_params(**kwargs)
    opt = parser.parse_args(print_args=False)
    with testing_utils.capture_output():
        return MockTorchAgent(opt)
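
With the helper in place, each test can build a CPU-only agent in one line; set_params seeds parser defaults from the keyword arguments before parse_args constructs the opt. A hypothetical usage:

agent = get_agent(batchsize=2)
assert agent.opt['no_cuda'] is True  # forced on unless explicitly overridden
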
github facebookresearch / ParlAI / tests / test_transformers.py
            # beam search without blocking (opening lines reconstructed from context)
            _, noblock_valid, test = testing_utils.eval_model(
                dict(
                    inference='beam',
                    beam_size=5,
                    **args,
                )
            )
            self.assertGreaterEqual(noblock_valid['f1'], 0.99)

            # first confirm all is good without blocking
            _, valid, test = testing_utils.eval_model(
                dict(beam_context_block_ngram=-1, **args)
            )
            self.assertGreaterEqual(valid['f1'], 0.99)
            self.assertGreaterEqual(valid['bleu-4'], 0.99)

            # there's a special case for block == 1
            _, valid, test = testing_utils.eval_model(
                dict(beam_context_block_ngram=1, **args)
            )
            # bleu and f1 should be totally wrecked.
            self.assertLess(valid['f1'], 0.01)
            self.assertLess(valid['bleu-4'], 0.01)

            # a couple general cases
            _, valid, test = testing_utils.eval_model(
                dict(beam_context_block_ngram=2, **args)
            )
            # should take a big hit here
            self.assertLessEqual(valid['f1'], noblock_valid['f1'])
            # bleu-1 should be relatively okay
            self.assertLessEqual(valid['bleu-1'], noblock_valid['bleu-1'])
            self.assertGreaterEqual(valid['bleu-1'], 0.50)
            # and bleu-2 should be 0 at this point
            self.assertLess(valid['bleu-2'], 0.01)
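
All of the blocking checks share one shape: re-run eval_model with a different beam_context_block_ngram value and compare metrics against the unblocked baseline. A sketch of that loop, with hypothetical base options standing in for the args dict the real test builds earlier:

import parlai.utils.testing as testing_utils

base_args = dict(
    task='integration_tests',     # hypothetical stand-ins for the test's args
    model_file='/path/to/model',
    inference='beam',
    beam_size=5,
)
for ngram in (-1, 1, 2):
    _, valid, test = testing_utils.eval_model(
        dict(beam_context_block_ngram=ngram, **base_args)
    )
    print(ngram, valid['f1'], valid['bleu-4'])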