How to use the simpletransformers.question_answering.question_answering_utils.write_predictions_extended function in simpletransformers

To help you get started, here is an example of how write_predictions_extended is used in a public project: the question-answering model of simpletransformers itself.


Source: ThilinaRajapakse/simpletransformers on GitHub, file simpletransformers/question_answering/question_answering_model.py
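import os

# Post-processing helpers called below
from simpletransformers.question_answering.question_answering_utils import (
    write_predictions,
    write_predictions_extended,
)
# RawResult, to_list, model, tokenizer, args, examples, features, eval_data and
# output_dir come from the surrounding module and method (not shown in this excerpt).

        # ... inside the loop over evaluation batches and features (elided):
                    result = RawResult(unique_id=unique_id,
                                       start_logits=to_list(outputs[0][i]),
                                       end_logits=to_list(outputs[1][i]))
                    all_results.append(result)

        prefix = 'test'
        # makedirs(..., exist_ok=True) also creates missing parent directories
        os.makedirs(output_dir, exist_ok=True)

        output_prediction_file = os.path.join(output_dir, "predictions_{}.json".format(prefix))
        output_nbest_file = os.path.join(output_dir, "nbest_predictions_{}.json".format(prefix))
        output_null_log_odds_file = os.path.join(output_dir, "null_odds_{}.json".format(prefix))

        if args['model_type'] in ['xlnet', 'xlm']:
            # XLNet and XLM use a beam-search QA head (start_n_top x end_n_top
            # candidate spans), so they need a more complex post-processing procedure
            all_predictions, all_nbest_json, scores_diff_json = write_predictions_extended(
                examples, features, all_results,
                args['n_best_size'], args['max_answer_length'],
                output_prediction_file, output_nbest_file, output_null_log_odds_file,
                eval_data,                                   # original evaluation data
                model.config.start_n_top, model.config.end_n_top,
                True,                                        # version_2_with_negative
                tokenizer,
                not args['silent'],                          # verbose_logging
            )
        else:
            all_predictions, all_nbest_json, scores_diff_json = write_predictions(
                examples, features, all_results,
                args['n_best_size'], args['max_answer_length'],
                False,                                       # do_lower_case
                output_prediction_file, output_nbest_file, output_null_log_odds_file,
                not args['silent'],                          # verbose_logging
                True,                                        # version_2_with_negative
                args['null_score_diff_threshold'],
            )

        return all_predictions, all_nbest_json, scores_diff_json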
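To make the XLNet/XLM branch more concrete, here is a minimal sketch of the candidate-ranking idea behind that extended post-processing. It assumes the model returns the top start_n_top start positions with their log-probabilities, plus start_n_top * end_n_top end positions flattened so that the j-th end candidate for the i-th start sits at index i * end_n_top + j. That is our understanding of how the Hugging Face XLNet/XLM question-answering heads lay out their outputs, but treat it as an assumption. The function rank_candidate_spans is hypothetical and is not the library's code. The real write_predictions_extended does more, including mapping token positions back to answer text and writing the three JSON files named above.

```python
from typing import List, Tuple


def rank_candidate_spans(
    start_top_log_probs: List[float],
    start_top_index: List[int],
    end_top_log_probs: List[float],
    end_top_index: List[int],
    start_n_top: int,
    end_n_top: int,
    max_answer_length: int,
    n_best_size: int,
) -> List[Tuple[int, int, float]]:
    """Hypothetical helper: rank (start, end, score) spans from top-k outputs.

    Assumes end_top_* are flattened so that the j-th end candidate for the
    i-th start candidate sits at index i * end_n_top + j.
    """
    candidates = []
    for i in range(start_n_top):
        for j in range(end_n_top):
            k = i * end_n_top + j
            start, end = start_top_index[i], end_top_index[k]
            # Skip spans that end before they start or exceed the length limit
            if end < start or end - start + 1 > max_answer_length:
                continue
            # Score the (start, end) pair by the sum of its log-probabilities
            score = start_top_log_probs[i] + end_top_log_probs[k]
            candidates.append((start, end, score))
    # Best spans first; keep at most n_best_size of them
    candidates.sort(key=lambda c: c[2], reverse=True)
    return candidates[:n_best_size]


if __name__ == "__main__":
    spans = rank_candidate_spans(
        start_top_log_probs=[-0.5, -1.0],
        start_top_index=[5, 8],
        end_top_log_probs=[-0.25, -0.75, -0.5, -2.0],
        end_top_index=[7, 4, 9, 20],
        start_n_top=2,
        end_n_top=2,
        max_answer_length=10,
        n_best_size=5,
    )
    print(spans)  # [(5, 7, -0.75), (8, 9, -1.5)]
```

In this toy input, the pair (5, 4) is dropped because it ends before it starts, and (8, 20) is dropped because it is longer than max_answer_length. The remaining spans are then ordered by their combined score.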

simpletransformers: an easy-to-use wrapper library for the Transformers library (license: Apache-2.0).