How to use the emmental.utils.utils.move_to_device function in emmental

To help you get started, we’ve selected a few emmental examples based on popular ways it is used in public projects.

Secure your code as it’s written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github SenWu / emmental / src / emmental / model.py View on Github external
Y.detach() != Meta.config["learner_config"]["ignore_index"],
                        dim=1,
                    )
            else:
                active = torch.BoolTensor([True] * Y.size()[0])  # type: ignore

            # Only calculate the loss when active example exists
            if active.any():
                uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
                    move_to_device(
                        Y_dict[label_name], Meta.config["model_config"]["device"]
                    ),
                    move_to_device(active, Meta.config["model_config"]["device"]),
                )

                prob_dict[task_name] = (
                    self.output_funcs[task_name](output_dict)[
                        move_to_device(active, Meta.config["model_config"]["device"])
                    ]
                    .cpu()
                    .detach()
                    .numpy()
                )

                gold_dict[task_name] = Y_dict[label_name][active].cpu().numpy()

        return uid_dict, loss_dict, prob_dict, gold_dict
github SenWu / emmental / src / emmental / model.py View on Github external
active = Y.detach() != Meta.config["learner_config"]["ignore_index"]
                else:
                    active = torch.any(
                        Y.detach() != Meta.config["learner_config"]["ignore_index"],
                        dim=1,
                    )
            else:
                active = torch.ByteTensor([True] * Y.size()[0])

            # Only calculate the loss when active example exists
            if active.any():
                uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
                    move_to_device(
                        Y_dict[label_name], Meta.config["model_config"]["device"]
                    ),
                    move_to_device(active, Meta.config["model_config"]["device"]),
                )

                prob_dict[task_name] = (
                    self.output_funcs[task_name](output_dict)[
                        move_to_device(active, Meta.config["model_config"]["device"])
                    ]
                    .cpu()
                    .detach()
                    .numpy()
                )

                gold_dict[task_name] = Y_dict[label_name][active].cpu().numpy()
github SenWu / emmental / src / emmental / model.py View on Github external
Y.detach() != Meta.config["learner_config"]["ignore_index"],
                        dim=1,
                    )
            else:
                active = torch.ByteTensor([True] * Y.size()[0])

            # Only calculate the loss when active example exists
            if active.any():
                uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
                    move_to_device(
                        Y_dict[label_name], Meta.config["model_config"]["device"]
                    ),
                    move_to_device(active, Meta.config["model_config"]["device"]),
                )

                prob_dict[task_name] = (
                    self.output_funcs[task_name](output_dict)[
                        move_to_device(active, Meta.config["model_config"]["device"])
                    ]
                    .cpu()
                    .detach()
                    .numpy()
                )

                gold_dict[task_name] = Y_dict[label_name][active].cpu().numpy()

        return uid_dict, loss_dict, prob_dict, gold_dict
github SenWu / emmental / src / emmental / model.py View on Github external
active = Y.detach() != Meta.config["learner_config"]["ignore_index"]
                else:
                    active = torch.any(
                        Y.detach() != Meta.config["learner_config"]["ignore_index"],
                        dim=1,
                    )
            else:
                active = torch.BoolTensor([True] * Y.size()[0])  # type: ignore

            # Only calculate the loss when active example exists
            if active.any():
                uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
                    move_to_device(
                        Y_dict[label_name], Meta.config["model_config"]["device"]
                    ),
                    move_to_device(active, Meta.config["model_config"]["device"]),
                )

                prob_dict[task_name] = (
                    self.output_funcs[task_name](output_dict)[
                        move_to_device(active, Meta.config["model_config"]["device"])
                    ]
                    .cpu()
                    .detach()
                    .numpy()
                )

                gold_dict[task_name] = Y_dict[label_name][active].cpu().numpy()
github SenWu / emmental / src / emmental / contrib / slicing / task.py View on Github external
],
                },
                {
                    "name": ind_head_module_name,
                    "module": ind_head_module_name,
                    "inputs": [(ind_head_dropout_module_name, 0)],
                },
            ]
        )

        # Loss function
        if ind_task_name in slice_distribution:
            loss = partial(
                utils.ce_loss,
                ind_head_module_name,
                weight=move_to_device(
                    slice_distribution[ind_task_name],
                    Meta.config["model_config"]["device"],
                ),
            )
        else:
            loss = partial(utils.ce_loss, ind_head_module_name)

        tasks.append(
            EmmentalTask(
                name=ind_task_name,
                module_pool=ind_module_pool,
                task_flow=ind_task_flow,
                loss_func=loss,
                output_func=partial(utils.output, ind_head_module_name),
                scorer=slice_scorer,
            )
github SenWu / emmental / src / emmental / model.py View on Github external
# Only calculate the loss when active example exists
            if active.any():
                uid_dict[task_name] = [*itertools.compress(uids, active.numpy())]

                loss_dict[task_name] = self.loss_funcs[task_name](
                    output_dict,
                    move_to_device(
                        Y_dict[label_name], Meta.config["model_config"]["device"]
                    ),
                    move_to_device(active, Meta.config["model_config"]["device"]),
                )

                prob_dict[task_name] = (
                    self.output_funcs[task_name](output_dict)[
                        move_to_device(active, Meta.config["model_config"]["device"])
                    ]
                    .cpu()
                    .detach()
                    .numpy()
                )

                gold_dict[task_name] = Y_dict[label_name][active].cpu().numpy()

        return uid_dict, loss_dict, prob_dict, gold_dict