How to use the simpletransformers.question_answering.question_answering_utils.RawResultExtended function in simpletransformers

To help you get started, we’ve selected a few simpletransformers examples based on popular ways the library is used in public projects.


Source: ThilinaRajapakse/simpletransformers, simpletransformers/question_answering/question_answering_model.py (GitHub)
                inputs['token_type_ids'] = None if args['model_type'] == 'xlm' else batch[2]

                example_indices = batch[3]

                if args['model_type'] in ['xlnet', 'xlm']:
                    inputs.update({'cls_index': batch[4],
                                   'p_mask':       batch[5]})

                outputs = model(**inputs)

                for i, example_index in enumerate(example_indices):
                    eval_feature = features[example_index.item()]
                    unique_id = int(eval_feature.unique_id)
                    if args['model_type'] in ['xlnet', 'xlm']:
                        # XLNet uses a more complex post-processing procedure
                        result = RawResultExtended(unique_id=unique_id,
                                                   start_top_log_probs=to_list(outputs[0][i]),
                                                   start_top_index=to_list(outputs[1][i]),
                                                   end_top_log_probs=to_list(outputs[2][i]),
                                                   end_top_index=to_list(outputs[3][i]),
                                                   cls_logits=to_list(outputs[4][i]))
                    else:
                        result = RawResult(unique_id=unique_id,
                                           start_logits=to_list(outputs[0][i]),
                                           end_logits=to_list(outputs[1][i]))
                    all_results.append(result)

        if args['model_type'] in ['xlnet', 'xlm']:
            answers = get_best_predictions_extended(examples, features, all_results, n_best_size,
                                                    args['max_answer_length'], model.config.start_n_top,
                                                    model.config.end_n_top, True, tokenizer,
                                                    args['null_score_diff_threshold'])
        else:
            answers = get_best_predictions(examples, features, all_results, n_best_size, args['max_answer_length'], False, False, True, False)
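The excerpt above builds one result record per example by taking row `i` from each batched model output. Below is a minimal, self-contained sketch of that pattern. It assumes `RawResult` and `RawResultExtended` are plain named tuples with the field names used in the calls above; the stand-in definitions, the `build_result` helper, and the use of nested Python lists in place of tensors are illustrative assumptions, not the library's own code.

```python
from collections import namedtuple

# Stand-in record types (assumed shape; field names are taken from the excerpt above).
RawResult = namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
RawResultExtended = namedtuple(
    "RawResultExtended",
    [
        "unique_id",
        "start_top_log_probs",
        "start_top_index",
        "end_top_log_probs",
        "end_top_index",
        "cls_logits",
    ],
)


def build_result(model_type, unique_id, outputs, i):
    """Hypothetical helper: wrap row i of each batched output in the matching record."""
    if model_type in ("xlnet", "xlm"):
        # XLNet/XLM heads return top-k start/end candidates plus a CLS score.
        return RawResultExtended(
            unique_id=unique_id,
            start_top_log_probs=list(outputs[0][i]),
            start_top_index=list(outputs[1][i]),
            end_top_log_probs=list(outputs[2][i]),
            end_top_index=list(outputs[3][i]),
            cls_logits=outputs[4][i],
        )
    # Other model types return one start logit and one end logit per token.
    return RawResult(
        unique_id=unique_id,
        start_logits=list(outputs[0][i]),
        end_logits=list(outputs[1][i]),
    )


# Fake batched outputs for a batch of two examples (values are made up).
xlnet_outputs = [
    [[-0.1, -2.3], [-0.5, -1.0]],  # start_top_log_probs
    [[3, 8], [1, 4]],              # start_top_index
    [[-0.2, -1.9], [-0.3, -1.2]],  # end_top_log_probs
    [[5, 9], [2, 6]],              # end_top_index
    [0.4, -1.1],                   # cls_logits
]
extended = build_result("xlnet", 1000000001, xlnet_outputs, 1)
basic = build_result("bert", 1000000002, [[[0.1, 0.9]], [[0.7, 0.3]]], 0)
```

Because both record types carry `unique_id`, downstream post-processing can match each result back to its feature regardless of which branch produced it.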
Source: ThilinaRajapakse/simpletransformers, simpletransformers/question_answering/question_answering_model.py (GitHub)
        # (The per-batch loop that builds RawResult / RawResultExtended records is identical to the previous excerpt.)

        prefix = 'test'
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)

        output_prediction_file = os.path.join(output_dir, "predictions_{}.json".format(prefix))
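The excerpt ends once the prediction file path is built, so the step that writes the file is not shown. The sketch below shows one way the directory check and path construction could be followed by a JSON dump. The `write_predictions` helper, the `answers` argument, and the JSON layout are assumptions for illustration only; the library's own writer may differ.

```python
import json
import os


def write_predictions(answers, output_dir, prefix="test"):
    """Hypothetical helper: save predictions as predictions_<prefix>.json."""
    # makedirs with exist_ok=True also creates missing parent directories,
    # which a bare isdir check followed by os.mkdir would not do.
    os.makedirs(output_dir, exist_ok=True)
    output_prediction_file = os.path.join(
        output_dir, "predictions_{}.json".format(prefix)
    )
    with open(output_prediction_file, "w", encoding="utf-8") as handle:
        json.dump(answers, handle, indent=2)
    return output_prediction_file
```

Returning the path lets the caller log or reload the file without rebuilding the name.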

simpletransformers

An easy-to-use wrapper library for the Transformers library.

Apache-2.0
Latest version published 6 months ago
