How to use the emmental.metrics.precision.precision_scorer function in emmental

To help you get started, we’ve selected a few emmental examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: SenWu / emmental / tests / metrics / test_metrics.py (View on GitHub)
def test_precision(caplog):
    """Check precision_scorer with GOLDS and PROB_GOLDS, with and without
    PROBS, for pos_label 1 and 0."""
    caplog.set_level(logging.INFO)

    # Each row: (gold labels, probabilities, positive label, expected precision).
    cases = [
        (GOLDS, PROBS, 1, 1),
        (GOLDS, None, 1, 1),
        (GOLDS, None, 0, 0.6),
        (PROB_GOLDS, PROBS, 1, 1),
        (PROB_GOLDS, None, 1, 1),
        (PROB_GOLDS, None, 0, 0.6),
    ]
    for golds, probs, label, expected in cases:
        result = precision_scorer(golds, probs, PREDS, pos_label=label)
        assert isequal(result, {"precision": expected})
GitHub: SenWu / emmental / tests / metrics / test_metrics.py (View on GitHub)
def test_precision(caplog):
    """Exercise precision_scorer for both gold-label sets (GOLDS, then
    PROB_GOLDS), with and without PROBS, and for pos_label 1 and 0."""
    caplog.set_level(logging.INFO)

    def check(golds, probs, pos_label, expected):
        # Score PREDS against the given golds and compare to the expected dict.
        scores = precision_scorer(golds, probs, PREDS, pos_label=pos_label)
        assert isequal(scores, {"precision": expected})

    for golds in (GOLDS, PROB_GOLDS):
        check(golds, PROBS, 1, 1)
        check(golds, None, 1, 1)
        check(golds, None, 0, 0.6)
GitHub: SenWu / emmental / tests / metrics / test_metrics.py (View on GitHub)
def test_precision(caplog):
    """Assert the precision reported by precision_scorer for PREDS across
    gold-label sets, optional probabilities, and positive labels."""
    caplog.set_level(logging.INFO)

    # (probabilities, positive label, expected precision), applied per gold set.
    expectations = (
        (PROBS, 1, 1),
        (None, 1, 1),
        (None, 0, 0.6),
    )
    for gold_labels in (GOLDS, PROB_GOLDS):
        for probabilities, label, want in expectations:
            got = precision_scorer(
                gold_labels, probabilities, PREDS, pos_label=label
            )
            assert isequal(got, {"precision": want})
GitHub: SenWu / emmental / tests / metrics / test_metrics.py (View on GitHub)
"""Unit test of precision_scorer."""
    caplog.set_level(logging.INFO)

    metric_dict = precision_scorer(GOLDS, PROBS, PREDS, pos_label=1)
    assert isequal(metric_dict, {"precision": 1})

    metric_dict = precision_scorer(GOLDS, None, PREDS, pos_label=1)
    assert isequal(metric_dict, {"precision": 1})

    metric_dict = precision_scorer(GOLDS, None, PREDS, pos_label=0)
    assert isequal(metric_dict, {"precision": 0.6})

    metric_dict = precision_scorer(PROB_GOLDS, PROBS, PREDS, pos_label=1)
    assert isequal(metric_dict, {"precision": 1})

    metric_dict = precision_scorer(PROB_GOLDS, None, PREDS, pos_label=1)
    assert isequal(metric_dict, {"precision": 1})

    metric_dict = precision_scorer(PROB_GOLDS, None, PREDS, pos_label=0)
    assert isequal(metric_dict, {"precision": 0.6})
GitHub: SenWu / emmental / src / emmental / metrics / __init__.py (View on GitHub)
from emmental.metrics.fbeta import f1_scorer, fbeta_scorer
from emmental.metrics.matthews_correlation import (
    matthews_correlation_coefficient_scorer,
)
from emmental.metrics.mean_squared_error import mean_squared_error_scorer
from emmental.metrics.pearson_correlation import pearson_correlation_scorer
from emmental.metrics.pearson_spearman import pearson_spearman_scorer
from emmental.metrics.precision import precision_scorer
from emmental.metrics.recall import recall_scorer
from emmental.metrics.roc_auc import roc_auc_scorer
from emmental.metrics.spearman_correlation import spearman_correlation_scorer

# Registry mapping each metric name to the scorer function that computes it.
# Entries are listed in registration order, which dict() preserves.
METRICS = dict(
    [
        ("accuracy", accuracy_scorer),
        ("accuracy_f1", accuracy_f1_scorer),
        ("precision", precision_scorer),
        ("recall", recall_scorer),
        ("f1", f1_scorer),
        ("fbeta", fbeta_scorer),
        ("matthews_correlation", matthews_correlation_coefficient_scorer),
        ("mean_squared_error", mean_squared_error_scorer),
        ("pearson_correlation", pearson_correlation_scorer),
        ("pearson_spearman", pearson_spearman_scorer),
        ("spearman_correlation", spearman_correlation_scorer),
        ("roc_auc", roc_auc_scorer),
    ]
)

__all__ = [
    "accuracy_scorer",
    "accuracy_f1_scorer",
    "f1_scorer",
    "fbeta_scorer",
GitHub: SenWu / emmental / src / emmental / metrics / fbeta.py (View on GitHub)
Args:
      golds: Ground truth values.
      probs: Predicted probabilities.
      preds: Predicted values.
      uids: Unique ids, defaults to None.
      pos_label: The positive class label, defaults to 1.
      beta: Relative weight of recall versus precision in the weighted
        harmonic mean (beta > 1 favors recall, beta < 1 favors precision),
        defaults to 1.

    Returns:
      Dictionary with a single key ``f"f{beta}"`` (e.g. ``"f1"`` for
      beta=1) mapping to the F-beta score.
    """
    # Convert probabilistic label to hard label
    # (a 2-D golds array is passed through prob_to_pred before scoring).
    if len(golds.shape) == 2:
        golds = prob_to_pred(golds)

    precision = precision_scorer(golds, probs, preds, uids, pos_label)["precision"]
    recall = recall_scorer(golds, probs, preds, uids, pos_label)["recall"]

    # F-beta = (1 + beta^2) * P * R / (beta^2 * P + R); return 0.0 when the
    # denominator is not positive, to avoid dividing by zero.
    fbeta = (
        (1 + beta ** 2) * (precision * recall) / ((beta ** 2 * precision) + recall)
        if (beta ** 2 * precision) + recall > 0
        else 0.0
    )

    return {f"f{beta}": fbeta}