How to use the emmental.meta.Meta.config attribute in emmental

To help you get started, we’ve selected a few emmental examples based on popular ways it is used in public projects. Each excerpt below reads runtime settings from the Meta.config dictionary.

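Meta.config is populated when Emmental is initialized. As a quick orientation, here is a minimal sketch of initializing the framework and then reading from the config (emmental.init's log_dir argument is an assumption based on emmental's documented setup, not taken from the excerpts below):

import emmental
from emmental import Meta

# Initialize Emmental; this builds the default configuration and
# exposes it as the nested dict Meta.config.
emmental.init(log_dir="logs")

# Values are read by section, e.g. the device the model runs on and
# whether verbose logging is enabled.
device = Meta.config["model_config"]["device"]
verbose = Meta.config["meta_config"]["verbose"]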

github SenWu / emmental / src/emmental/model.py
def flow(self, X_dict: Dict[str, Any], task_names: List[str]) -> Dict[str, Any]:
        """Forward based on input and task flow.

        Note:
          We assume that all shared modules from all tasks are based on the
          same input.

        Args:
          X_dict: The input data
          task_names: The names of the tasks to run forward.

        Returns:
          The output of all forwarded modules
        """
        X_dict = move_to_device(X_dict, Meta.config["model_config"]["device"])

        output_dict = dict(_input_=X_dict)

        # Call forward for each task
        for task_name in task_names:
            for action in self.task_flows[task_name]:
                if action["name"] not in output_dict:
                    if action["inputs"]:
                        try:
                            input = [
                                output_dict[action_name][output_index]
                                for action_name, output_index in action["inputs"]
                            ]
                        except Exception:
                            raise ValueError(f"Unrecognized action {action}.")
                        output = self.module_pool[action["module"]].forward(*input)
                    else:
                        # Reconstructed continuation (the original listing is
                        # truncated here): modules with no declared inputs
                        # receive the full output_dict.
                        output = self.module_pool[action["module"]].forward(
                            output_dict
                        )
                    # Normalize to a list so later actions can address each
                    # output by (action_name, output_index).
                    if not isinstance(output, (list, tuple)):
                        output = [output]
                    output_dict[action["name"]] = list(output)

        return output_dict
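
The first thing flow does is move the whole batch onto the device named in model_config. The same pattern works outside the model; a minimal sketch, assuming move_to_device is importable from emmental.utils.utils and using a made-up batch:

import torch
from emmental import Meta
from emmental.utils.utils import move_to_device

# Hypothetical batch; move_to_device walks nested dicts/lists and moves
# every tensor to the configured device (e.g. -1 for CPU, 0 for GPU 0).
X_dict = {"feature": torch.zeros(4, 16)}
X_dict = move_to_device(X_dict, Meta.config["model_config"]["device"])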

github SenWu / emmental / src/emmental/logging/tensorboard_writer.py
def write_config(self, config_filename: str = "config.yaml") -> None:
        """Write the config to tensorboard and dump it to file.

        Args:
          config_filename: The config filename, defaults to "config.yaml".
        """
        config = json.dumps(Meta.config)
        self.writer.add_text(tag="config", text_string=config)

        super().write_config(config_filename)
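
Because Meta.config is plain, JSON-serializable data, it can be serialized for experiment tracking exactly as the writer does above. For instance, to inspect the full configuration of a run:

import json
from emmental import Meta

# Pretty-print every section of the current configuration.
print(json.dumps(Meta.config, indent=2))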

github SenWu / emmental / src/emmental/contrib/slicing/task.py
                # Route the transformed representation into the shared
                # prediction head (tail of the task-flow definition).
                {
                    "name": pred_head_module_name,
                    "module": shared_pred_head_module_name,
                    "inputs": [(pred_transform_module_name, 0)],
                },
            ]
        )

        # Loss function
        if pred_task_name in slice_distribution:
            loss = partial(
                utils.ce_loss,
                pred_head_module_name,
                weight=move_to_device(
                    slice_distribution[pred_task_name],
                    Meta.config["model_config"]["device"],
                ),
            )
        else:
            loss = partial(utils.ce_loss, pred_head_module_name)

        tasks.append(
            EmmentalTask(
                name=pred_task_name,
                module_pool=pred_module_pool,
                task_flow=pred_task_flow,
                loss_func=loss,
                output_func=partial(utils.output, pred_head_module_name),
                scorer=task.scorer,
            )
        )
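
The slicing code uses functools.partial to pre-bind the head module's name (and, when a slice distribution exists, a class-weight tensor moved to the configured device) into the loss function. Below is a simplified stand-in for that pattern; example_loss and "pred_head" are hypothetical names, not emmental APIs:

import torch
import torch.nn.functional as F
from functools import partial

def example_loss(module_name, output_dict, Y, weight=None):
    # Hypothetical loss: pull the bound head's logits out of the flow
    # output and compute an (optionally class-weighted) cross entropy.
    logits = output_dict[module_name][0]
    return F.cross_entropy(logits, Y, weight=weight)

# Pre-bind the module name so callers only supply (output_dict, Y).
loss = partial(example_loss, "pred_head")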

github SenWu / emmental / src/emmental/task.py
    def __init__(  # signature head reconstructed; the excerpt begins mid-definition
        self,
        name: str,
        module_pool: nn.ModuleDict,
        task_flow: List[Dict[str, Any]],
        loss_func: Callable,
        output_func: Callable,
        scorer: Scorer,
        weight: Union[float, int] = 1.0,
    ) -> None:
        """Initialize EmmentalTask."""
        self.name = name
        assert isinstance(module_pool, nn.ModuleDict)
        self.module_pool = module_pool
        self.task_flow = task_flow
        self.loss_func = loss_func
        self.output_func = output_func
        self.scorer = scorer
        self.weight = weight

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Created task: {self.name}")
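
meta_config["verbose"] gates framework-level log messages like the one above. Configuration values can be overridden before tasks are built; a minimal sketch, assuming Meta.update_config accepts a config dict as in emmental's documentation:

from emmental import Meta

# Merge an override into the current configuration; here, silence the
# "Created task" messages.
Meta.update_config(config={"meta_config": {"verbose": False}})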

github SenWu / emmental / src/emmental/data.py
        # Only merge lists of tensors
        if isinstance(values[0], Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor
            if item_mask_tensor is not None:
                X_batch[f"{field_name}_mask"] = item_mask_tensor

    for label_name, values in Y_batch.items():
        Y_batch[label_name] = list_to_tensor(
            values,
            min_len=Meta.config["data_config"]["min_data_len"],
            max_len=Meta.config["data_config"]["max_data_len"],
        )[0]

    return dict(X_batch), dict(Y_batch)
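
Both loops read their padding bounds from data_config. The same call on a toy batch, assuming list_to_tensor is importable from emmental.utils.utils:

import torch
from emmental import Meta
from emmental.utils.utils import list_to_tensor

# Pad a ragged list of sequences into a single tensor; the second
# return value is a mask marking the padded positions.
values = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
item_tensor, item_mask_tensor = list_to_tensor(
    values,
    min_len=Meta.config["data_config"]["min_data_len"],
    max_len=Meta.config["data_config"]["max_data_len"],
)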

github SenWu / emmental / src/emmental/model.py
"""

        uid_dict: Dict[str, List[str]] = defaultdict(list)
        loss_dict: Dict[str, ndarray] = defaultdict(float)
        gold_dict: Dict[str, ndarray] = defaultdict(list)
        prob_dict: Dict[str, ndarray] = defaultdict(list)

        output_dict = self.flow(X_dict, list(task_to_label_dict.keys()))

        # Calculate loss for each task
        for task_name, label_name in task_to_label_dict.items():
            Y = Y_dict[label_name]

            # Select the active samples
            if Meta.config["learner_config"]["ignore_index"] is not None:
                if len(Y.size()) == 1:
                    active = Y.detach() != Meta.config["learner_config"]["ignore_index"]
                else:
                    active = torch.any(
                        Y.detach() != Meta.config["learner_config"]["ignore_index"],
                        dim=1,
                    )
            else:
                active = torch.ByteTensor([True] * Y.size()[0])

            # Only calculate the loss when active example exists
            if active.any():
                uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
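
The active-sample selection above compares gold labels against learner_config["ignore_index"]. The same masking in isolation, with a made-up ignore value:

import torch

# Samples whose label equals ignore_index are dropped from the loss.
ignore_index = -100  # hypothetical value, for illustration only
Y = torch.tensor([0, 1, ignore_index, 2])
active = Y != ignore_index  # tensor([True, True, False, True])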

github SenWu / emmental / src/emmental/logging/checkpointer.py

        # Set up checkpoint unit
        self.checkpoint_unit = Meta.config["logging_config"]["counter_unit"]

        logger.info(
            f"Save checkpoints at {self.checkpoint_path} every "
            f"{self.checkpoint_freq} {self.checkpoint_unit}"
        )

        # Set up checkpoint metric
        self.checkpoint_metric = Meta.config["logging_config"]["checkpointer_config"][
            "checkpoint_metric"
        ]

        self.checkpoint_all_metrics = Meta.config["logging_config"][
            "checkpointer_config"
        ]["checkpoint_task_metrics"]

        # Collect all metrics to checkpoint
        if self.checkpoint_all_metrics is None:
            self.checkpoint_all_metrics = dict()

        if self.checkpoint_metric:
            self.checkpoint_all_metrics.update(self.checkpoint_metric)

        # Check evaluation metric mode
        for metric, mode in self.checkpoint_all_metrics.items():
            if mode not in ["min", "max"]:
                raise ValueError(
                    f"Unrecognized checkpoint metric mode {mode} for metric {metric}, "
                    f"must be 'min' or 'max'."