How to use the pyserini.trectools.TrecRun class in pyserini

To help you get started, we’ve selected a few pyserini examples that reflect popular ways TrecRun is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: castorini / pyserini / integrations / test_simple_fusion_search_integration.py (view on GitHub)
def test_simple_fusion_searcher(self):
        """Fuse the CORD-19 abstract, full-text and paragraph indexes with RRF and compare against Anserini.

        Only topic, docid and rank are compared, because the two runs store scores with different
        floating-point precisions.
        """
        index_dirs = ['indexes/lucene-index-cord19-abstract-2020-05-01/',
                      'indexes/lucene-index-cord19-full-text-2020-05-01/',
                      'indexes/lucene-index-cord19-paragraph-2020-05-01/']

        searcher = SimpleFusionSearcher(index_dirs, method=FusionMethod.RRF)

        runs, topics = [], get_topics('covid_round2')
        for topic in tqdm(sorted(topics.keys())):
            query = topics[topic]['question'] + ' ' + topics[topic]['query']
            hits = searcher.search(query, k=10000, query_generator=None, strip_segment_id=True, remove_dups=True)
            docid_score_pair = [(hit.docid, hit.score) for hit in hits]
            run = TrecRun.from_search_results(docid_score_pair, topic=topic)
            runs.append(run)

        all_topics_run = TrecRun.concat(runs)
        all_topics_run.save_to_txt(output_path='runs/fused.txt', tag='reciprocal_rank_fusion_k=60')

        def topic_docid_rank(path):
            # Python equivalent of awk '{print $1" "$3" "$4}': split each line on whitespace and keep
            # columns 1 (topic), 3 (docid) and 4 (rank). Missing columns become '' as they would in awk.
            projected = []
            with open(path, encoding='utf-8') as f:
                for line in f:
                    cols = line.split() + [''] * 4
                    projected.append(f'{cols[0]} {cols[2]} {cols[3]}')
            return projected

        # Project in Python instead of shelling out to awk. With os.system the shell redirect creates
        # empty output files even when awk is missing or fails, so filecmp could report a vacuous match.
        # assertEqual on the line lists also shows exactly which lines differ on failure.
        self.assertEqual(topic_docid_rank('runs/fused.txt'),
                         topic_docid_rank('runs/anserini.covid-r2.fusion1.txt'))
GitHub: castorini / pyserini / tests / test_trectools.py (view on GitHub)
def test_discard_qrels(self):
        """Filter a run with ``discard_qrels`` (clone=False) and compare the written result with the fixture."""
        source_run = TrecRun('tests/resources/simple_trec_run_filter.txt')
        judgments = Qrels('tools/topics-and-qrels/qrels.covid-round1.txt')

        # clone=False: presumably filters source_run itself rather than a copy; the returned run is saved.
        filtered = source_run.discard_qrels(judgments, clone=False)
        filtered.save_to_txt(output_path=self.output_path)

        expected_path = 'tests/resources/simple_trec_run_remove_verify.txt'
        files_match = filecmp.cmp(expected_path, self.output_path)
        self.assertTrue(files_match)
GitHub: castorini / pyserini / tests / test_trectools.py (view on GitHub)
def test_trec_run_read(self):
        """Load a run file into ``TrecRun``, write it back out, and compare with the expected output."""
        loaded = TrecRun(filepath='tests/resources/simple_trec_run_read.txt')
        loaded.save_to_txt(self.output_path)

        expected = 'tests/resources/simple_trec_run_read_verify.txt'
        self.assertTrue(filecmp.cmp(expected, self.output_path))
GitHub: castorini / pyserini / tests / test_trectools.py (view on GitHub)
def test_normalize_scores(self):
        """Rescore a run with ``RescoreMethod.NORMALIZE`` and compare the written result with the fixture."""
        fusion_input = TrecRun('tests/resources/simple_trec_run_fusion_1.txt')
        normalized = fusion_input.rescore(RescoreMethod.NORMALIZE)
        normalized.save_to_txt(self.output_path)

        reference = 'tests/resources/simple_trec_run_normalize_verify.txt'
        outputs_match = filecmp.cmp(reference, self.output_path)
        self.assertTrue(outputs_match)
GitHub: castorini / pyserini / integrations / test_simple_fusion_search_integration.py (view on GitHub)
def test_simple_fusion_searcher(self):
        """Fuse the CORD-19 abstract, full-text and paragraph indexes with RRF and compare against Anserini.

        Only topic, docid and rank are compared, because the two runs store scores with different
        floating-point precisions.
        """
        index_dirs = ['indexes/lucene-index-cord19-abstract-2020-05-01/',
                      'indexes/lucene-index-cord19-full-text-2020-05-01/',
                      'indexes/lucene-index-cord19-paragraph-2020-05-01/']

        searcher = SimpleFusionSearcher(index_dirs, method=FusionMethod.RRF)

        runs, topics = [], get_topics('covid_round2')
        for topic in tqdm(sorted(topics.keys())):
            query = topics[topic]['question'] + ' ' + topics[topic]['query']
            hits = searcher.search(query, k=10000, query_generator=None, strip_segment_id=True, remove_dups=True)
            docid_score_pair = [(hit.docid, hit.score) for hit in hits]
            run = TrecRun.from_search_results(docid_score_pair, topic=topic)
            runs.append(run)

        all_topics_run = TrecRun.concat(runs)
        all_topics_run.save_to_txt(output_path='runs/fused.txt', tag='reciprocal_rank_fusion_k=60')

        def topic_docid_rank(path):
            # Python equivalent of awk '{print $1" "$3" "$4}': split each line on whitespace and keep
            # columns 1 (topic), 3 (docid) and 4 (rank). Missing columns become '' as they would in awk.
            projected = []
            with open(path, encoding='utf-8') as f:
                for line in f:
                    cols = line.split() + [''] * 4
                    projected.append(f'{cols[0]} {cols[2]} {cols[3]}')
            return projected

        # Project in Python instead of shelling out to awk. With os.system the shell redirect creates
        # empty output files even when awk is missing or fails, so filecmp could report a vacuous match.
        # assertEqual on the line lists also shows exactly which lines differ on failure.
        self.assertEqual(topic_docid_rank('runs/fused.txt'),
                         topic_docid_rank('runs/anserini.covid-r2.fusion1.txt'))
GitHub: castorini / pyserini / tests / test_trectools.py (view on GitHub)
def test_retain_qrels(self):
        """Filter a cloned run with ``retain_qrels`` and compare the written result with the fixture."""
        source_run = TrecRun('tests/resources/simple_trec_run_filter.txt')
        judgments = Qrels('tools/topics-and-qrels/qrels.covid-round1.txt')

        kept = source_run.retain_qrels(judgments, clone=True)
        kept.save_to_txt(output_path=self.output_path)

        expected_path = 'tests/resources/simple_trec_run_keep_verify.txt'
        self.assertTrue(filecmp.cmp(expected_path, self.output_path))
GitHub: castorini / pyserini / pyserini / fusion / _base.py (view on GitHub)
depth : int
        Maximum number of results from each input run to consider. Set to ``None`` by default, which indicates that
        the complete list of results is considered.
    k : int
        Length of final results list.  Set to ``None`` by default, which indicates that the union of all input documents
        are ranked.

    Returns
    -------
    TrecRun
        Output ``TrecRun`` that combines input runs via reciprocal rank fusion.
    """

    # TODO: Add option to *not* clone runs, thus making the method destructive, but also more efficient.
    rrf_runs = [run.clone().rescore(method=RescoreMethod.RRF, rrf_k=rrf_k) for run in runs]
    return TrecRun.merge(rrf_runs, AggregationMethod.SUM, depth=depth, k=k)
GitHub: castorini / pyserini / pyserini / fusion / _base.py (view on GitHub)
are ranked.

    Returns
    -------
    TrecRun
        Output ``TrecRun`` that combines input runs via interpolation.
    """

    if len(runs) != 2:
        raise Exception('Interpolation must be performed on exactly two runs.')

    scaled_runs = []
    scaled_runs.append(runs[0].clone().rescore(method=RescoreMethod.SCALE, scale=alpha))
    scaled_runs.append(runs[1].clone().rescore(method=RescoreMethod.SCALE, scale=(1-alpha)))

    return TrecRun.merge(scaled_runs, AggregationMethod.SUM, depth=depth, k=k)
GitHub: castorini / pyserini / pyserini / fusion / __main__.py (view on GitHub)
# Build the command-line interface for fusing several TREC run files into a single run.
parser = argparse.ArgumentParser(description='Perform various ways of fusion given a list of trec run files.')

# Input runs and the output location/tag.
parser.add_argument('--runs', nargs='+', type=str, required=True, default=[],
                    help='List of run files separated by space.')
parser.add_argument('--output', required=True, type=str,
                    help='Path to resulting fused txt.')
parser.add_argument('--runtag', default='pyserini.fusion', type=str,
                    help='Tag name of fused run.')

# Which fusion method to apply, plus the knobs specific to individual methods.
parser.add_argument('--method', default=FusionMethod.RRF, type=FusionMethod,
                    help='The fusion method to be used.')
parser.add_argument('--rrf.k', dest='rrf_k', default=60, type=int,
                    help='Parameter k needed for reciprocal rank fusion.')
parser.add_argument('--alpha', required=False, default=0.5, type=float,
                    help='Alpha value used for interpolation.')

# Per-topic cut-offs passed to every fusion method.
parser.add_argument('--depth', required=False, default=1000, type=int,
                    help='Pool depth per topic.')
parser.add_argument('--k', required=False, default=1000, type=int,
                    help='Number of documents to output per topic.')
args = parser.parse_args()

# Load every input run from disk.
trec_runs = [TrecRun(filepath=run_path) for run_path in args.runs]

# Dispatch on the requested method: only RRF uses --rrf.k and only interpolation uses --alpha.
fused_run = None
if args.method == FusionMethod.RRF:
    fused_run = reciprocal_rank_fusion(trec_runs, rrf_k=args.rrf_k,
                                       depth=args.depth, k=args.k)
elif args.method == FusionMethod.INTERPOLATION:
    fused_run = interpolation(trec_runs, alpha=args.alpha,
                              depth=args.depth, k=args.k)
elif args.method == FusionMethod.AVERAGE:
    fused_run = average(trec_runs, depth=args.depth, k=args.k)
else:
    raise NotImplementedError(f'Fusion method {args.method} not implemented.')

# Write the fused run, labelled with the requested run tag.
fused_run.save_to_txt(args.output, tag=args.runtag)
GitHub: castorini / pyserini / pyserini / search / _searcher.py (view on GitHub)
def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None, strip_segment_id=False, remove_dups=False) -> List[JSimpleSearcherResult]:
        """Search every underlying index and fuse the per-index rankings into a single hit list.

        Parameters
        ----------
        q : Union[str, JQuery]
            Query string or ``JQuery`` object, forwarded unchanged to each underlying searcher.
        k : int
            Number of hits requested from each underlying searcher, and the maximum number of fused
            hits returned. The fused list is additionally capped at 1000 documents.
        query_generator : JQueryGenerator
            Passed through to each underlying searcher.
        strip_segment_id : bool
            Passed through to each underlying searcher.
        remove_dups : bool
            Passed through to each underlying searcher.

        Returns
        -------
        List[JSimpleSearcherResult]
            Fused hits, built by ``SimpleFusionSearcher.convert_to_search_result``.

        Raises
        ------
        NotImplementedError
            If ``self.method`` is not ``FusionMethod.RRF``.
        """
        trec_runs = []
        docid_to_search_result = {}

        for searcher in self.searchers:
            hits = searcher.search(q, k=k, query_generator=query_generator,
                                   strip_segment_id=strip_segment_id, remove_dups=remove_dups)

            docid_score_pair = []
            for hit in hits:
                # If several indexes return the same docid, the hit from the last searcher wins.
                docid_to_search_result[hit.docid] = hit
                docid_score_pair.append((hit.docid, hit.score))

            trec_runs.append(TrecRun.from_search_results(docid_score_pair))

        if self.method == FusionMethod.RRF:
            # Each of the N runs contributes up to k docs, so their union can exceed k. Cap the fused
            # output at k so callers never receive more hits than requested. Keep the existing
            # 1000-document depth and ceiling so large-k callers see unchanged results.
            fused_run = reciprocal_rank_fusion(trec_runs, rrf_k=60, depth=1000, k=min(k, 1000))
        else:
            raise NotImplementedError(f'Fusion method {self.method} not implemented.')

        return SimpleFusionSearcher.convert_to_search_result(fused_run, docid_to_search_result)