How to use the emmental.utils.utils.prob_to_pred function in emmental

To help you get started, we’ve selected a few emmental examples based on popular ways prob_to_pred is used in public projects.
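
Before looking at the project snippets, here is a minimal, self-contained sketch of what prob_to_pred does on its own: it collapses row-wise class probabilities into hard label indices via an argmax, which is the behaviour every snippet below relies on. The toy array and printed output are illustrative assumptions, not taken from the projects.

import numpy as np
from emmental.utils.utils import prob_to_pred

# Row-wise class probabilities for three samples over two classes.
probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])

# prob_to_pred collapses each probability row to its most likely class index.
preds = prob_to_pred(probs)
print(preds)  # expected: [0 1 1]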


github SenWu / emmental / src / emmental / learner.py (View on GitHub)
        if total_count > 0:
            total_loss = sum(self.running_losses.values())
            metric_dict["model/all/train/loss"] = total_loss / total_count

        if calc_running_scores:
            micro_score_dict: Dict[str, List[ndarray]] = defaultdict(list)
            macro_score_dict: Dict[str, List[ndarray]] = defaultdict(list)

            # Calculate training metric
            for identifier in self.running_uids.keys():
                task_name, data_name, split = identifier.split("/")

                metric_score = model.scorers[task_name].score(
                    self.running_golds[identifier],
                    self.running_probs[identifier],
                    prob_to_pred(self.running_probs[identifier]),
                    self.running_uids[identifier],
                )
                for metric_name, metric_value in metric_score.items():
                    metric_dict[f"{identifier}/{metric_name}"] = metric_value

                # Collect average score
                identifier = construct_identifier(
                    task_name, data_name, split, "average"
                )

                metric_dict[identifier] = np.mean(list(metric_score.values()))

                micro_score_dict[split].extend(list(metric_score.values()))
                macro_score_dict[split].append(metric_dict[identifier])

            # Collect split-wise micro/macro average score
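
The pattern above caches golds, probs, and uids per task/data/split identifier and only derives hard predictions with prob_to_pred at scoring time. Below is a hedged, standalone sketch of that scorer call; it assumes emmental.scorer.Scorer accepts a metrics list and that its score method takes (golds, probs, preds, uids) as in the snippet, and the toy data is made up.

import numpy as np
from emmental.scorer import Scorer
from emmental.utils.utils import prob_to_pred

golds = np.array([0, 1, 1, 0])
probs = np.array([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6], [0.9, 0.1]])
uids = [f"uid{i}" for i in range(len(golds))]

# Hard predictions are derived from the cached probabilities only when scoring.
scorer = Scorer(metrics=["accuracy"])
metric_score = scorer.score(golds, probs, prob_to_pred(probs), uids)
print(metric_score)  # expected: {'accuracy': 1.0}
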
github SenWu / emmental / src / emmental / metrics / accuracy.py (View on GitHub)
    Args:
      golds: Ground truth values.
      probs: Predicted probabilities.
      preds: Predicted values.
      uids: Unique ids, defaults to None.
      normalize: Normalize the results or not, defaults to True.
      topk: Top K accuracy, defaults to 1.

    Returns:
      Accuracy, if normalize is True, return the fraction of correctly predicted
      samples (float), else returns the number of correctly predicted samples (int).
    """
    # Convert probabilistic label to hard label
    if len(golds.shape) == 2:
        golds = prob_to_pred(golds)

    if topk == 1 and preds is not None:
        n_matches = np.where(golds == preds)[0].shape[0]
    else:
        topk_preds = probs.argsort(axis=1)[:, -topk:][:, ::-1]
        n_matches = np.logical_or.reduce(
            topk_preds == golds.reshape(-1, 1), axis=1
        ).sum()

    if normalize:
        return {
            "accuracy" if topk == 1 else f"accuracy@{topk}": n_matches / golds.shape[0]
        }
    else:
        return {"accuracy" if topk == 1 else f"accuracy@{topk}": n_matches}
github SenWu / emmental / src / emmental / metrics / roc_auc.py (View on GitHub)
      preds: Predicted values.
      uids: Unique ids, defaults to None.
      pos_label: The positive class label, defaults to 1.

    Returns:
      ROC AUC score.
    """
    if len(probs.shape) == 2 and probs.shape[1] == 1:
        probs = probs.reshape(probs.shape[0])

    if len(golds.shape) == 2 and golds.shape[1] == 1:
        golds = golds.reshape(golds.shape[0])

    if len(probs.shape) > 1:
        if len(golds.shape) > 1:
            golds = pred_to_prob(prob_to_pred(golds), n_classes=probs.shape[1])
        else:
            golds = pred_to_prob(golds, n_classes=probs.shape[1])
    else:
        if len(golds.shape) > 1:
            golds = prob_to_pred(golds)

    try:
        roc_auc = roc_auc_score(golds, probs)
    except ValueError:
        logger.warning(
            "Only one class present in golds."
            "ROC AUC score is not defined in that case, set as nan instead."
        )
        roc_auc = float("nan")

    return {"roc_auc": roc_auc}
github SenWu / emmental / src / emmental / metrics / matthews_correlation.py (View on GitHub)
    uids: Optional[List[str]] = None,
) -> Dict[str, float]:
    """Matthews correlation coefficient (MCC).

    Args:
      golds: Ground truth values.
      probs: Predicted probabilities.
      preds: Predicted values.
      uids: Unique ids, defaults to None.

    Returns:
      Matthews correlation coefficient score.
    """
    # Convert probabilistic label to hard label
    if len(golds.shape) == 2:
        golds = prob_to_pred(golds)

    return {"matthews_corrcoef": matthews_corrcoef(golds, preds)}
github SenWu / emmental / src / emmental / model.py (View on GitHub)
        # Calculate average loss
        for task_name in uid_dict.keys():
            if not isinstance(loss_dict[task_name], list):
                loss_dict[task_name] /= len(uid_dict[task_name])

        res = {
            "uids": uid_dict,
            "golds": gold_dict,
            "probs": prob_dict,
            "losses": loss_dict,
        }

        if return_preds:
            for task_name, prob in prob_dict.items():
                pred_dict[task_name] = prob_to_pred(prob)
            res["preds"] = pred_dict

        return res
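
When return_preds is set, the result dict gains a preds entry whose values are simply prob_to_pred applied to the per-task probabilities. Below is a hedged, self-contained mock of that structure; the task name, uids, and values are invented for illustration.

import numpy as np
from emmental.utils.utils import prob_to_pred

prob_dict = {"my_task": np.array([[0.8, 0.2], [0.1, 0.9], [0.45, 0.55]])}
pred_dict = {name: prob_to_pred(prob) for name, prob in prob_dict.items()}

res = {
    "uids": {"my_task": ["uid0", "uid1", "uid2"]},
    "probs": prob_dict,
    "preds": pred_dict,
}
print(res["preds"]["my_task"])  # expected: [0 1 1]
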
github SenWu / emmental / src / emmental / model.py (View on GitHub)
        )

        # Calculate average loss
        for task_name in uid_dict.keys():
            loss_dict[task_name] /= len(uid_dict[task_name])

        res = {
            "uids": uid_dict,
            "golds": gold_dict,
            "probs": prob_dict,
            "losses": loss_dict,
        }

        if return_preds:
            for task_name, prob in prob_dict.items():
                pred_dict[task_name] = prob_to_pred(prob)
            res["preds"] = pred_dict

        return res
github SenWu / emmental / src / emmental / metrics / recall.py (View on GitHub)
) -> Dict[str, float]:
    """Recall.

    Args:
      golds: Ground truth values.
      probs: Predicted probabilities.
      preds: Predicted values.
      uids: Unique ids, defaults to None.
      pos_label: The positive class label, defaults to 1.

    Returns:
      Recall.
    """
    # Convert probabilistic label to hard label
    if len(golds.shape) == 2:
        golds = prob_to_pred(golds)

    pred_pos = np.where(preds == pos_label, True, False)
    gt_pos = np.where(golds == pos_label, True, False)
    TP = np.sum(pred_pos * gt_pos)
    FN = np.sum(np.logical_not(pred_pos) * gt_pos)

    recall = TP / (TP + FN) if TP + FN > 0 else 0.0

    return {"recall": recall}
github SenWu / emmental / src / emmental / metrics / fbeta.py (View on GitHub)
"""F-beta score is the weighted harmonic mean of precision and recall.

    Args:
      golds: Ground truth values.
      probs: Predicted probabilities.
      preds: Predicted values.
      uids: Unique ids, defaults to None.
      pos_label: The positive class label, defaults to 1.
      beta: Weight of precision in harmonic mean, defaults to 1.

    Returns:
      F-beta score.
    """
    # Convert probabilistic label to hard label
    if len(golds.shape) == 2:
        golds = prob_to_pred(golds)

    precision = precision_scorer(golds, probs, preds, uids, pos_label)["precision"]
    recall = recall_scorer(golds, probs, preds, uids, pos_label)["recall"]

    fbeta = (
        (1 + beta ** 2) * (precision * recall) / ((beta ** 2 * precision) + recall)
        if (beta ** 2 * precision) + recall > 0
        else 0.0
    )

    return {f"f{beta}": fbeta}
github SenWu / emmental / src / emmental / metrics / precision.py (View on GitHub)
) -> Dict[str, float]:
    """Precision.

    Args:
      golds: Ground truth values.
      probs: Predicted probabilities.
      preds: Predicted values.
      uids: Unique ids, defaults to None.
      pos_label: The positive class label, defaults to 1.

    Returns:
      Precision.
    """
    # Convert probabilistic label to hard label
    if len(golds.shape) == 2:
        golds = prob_to_pred(golds)

    pred_pos = np.where(preds == pos_label, True, False)
    gt_pos = np.where(golds == pos_label, True, False)
    TP = np.sum(pred_pos * gt_pos)
    FP = np.sum(pred_pos * np.logical_not(gt_pos))

    precision = TP / (TP + FP) if TP + FP > 0 else 0.0

    return {"precision": precision}