# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Capture INFO-level logs for the test and (re)initialize Emmental with a
# scratch directory so the run starts from a clean Meta state.
caplog.set_level(logging.INFO)
dirpath = "temp_test_model"
Meta.reset()
emmental.init(dirpath)
def ce_loss(module_name, immediate_output_dict, Y, active):
    """Cross-entropy loss over the active rows of one module's output.

    Args:
        module_name: Key into ``immediate_output_dict`` whose first entry
            holds the logits.
        immediate_output_dict: Mapping of module name -> list of outputs.
        Y: Label tensor; flattened before the active mask is applied.
        active: Boolean mask selecting which rows contribute to the loss.
    """
    logits = immediate_output_dict[module_name][0]
    targets = Y.view(-1)
    return F.cross_entropy(logits[active], targets[active])
def output(module_name, immediate_output_dict):
    """Return row-wise softmax probabilities of one module's raw output."""
    logits = immediate_output_dict[module_name][0]
    return F.softmax(logits, dim=1)
# Task "task_1": data -> m1 (10x10 linear) -> m2 (10x2 linear); the loss and
# output functions read m2's logits via the "m2" key, scored on accuracy.
# NOTE(review): statement order matters — nn.Linear init consumes RNG state.
task1 = EmmentalTask(
    name="task_1",
    module_pool=nn.ModuleDict(
        {"m1": nn.Linear(10, 10, bias=False), "m2": nn.Linear(10, 2, bias=False)}
    ),
    task_flow=[
        {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
        {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
    ],
    loss_func=partial(ce_loss, "m2"),
    output_func=partial(output, "m2"),
    scorer=Scorer(metrics=["accuracy"]),
)
new_task1 = EmmentalTask(
name="task_1",
module_pool=nn.ModuleDict(
# A second task also named "task_1" but with a narrower hidden layer
# (10->5->2) — presumably used to exercise task replacement/update in the
# model; TODO confirm against the surrounding test.
new_task1 = EmmentalTask(
    name="task_1",
    module_pool=nn.ModuleDict(
        {"m1": nn.Linear(10, 5, bias=False), "m2": nn.Linear(5, 2, bias=False)}
    ),
    task_flow=[
        {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
        {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
    ],
    loss_func=partial(ce_loss, "m2"),
    output_func=partial(output, "m2"),
    scorer=Scorer(metrics=["accuracy"]),
)
# Task "task_2": data -> m1 (10x5 linear) -> m2 (5x2 linear); same two-step
# flow and accuracy scorer as task_1, under a distinct task name.
task2 = EmmentalTask(
    name="task_2",
    module_pool=nn.ModuleDict(
        {"m1": nn.Linear(10, 5, bias=False), "m2": nn.Linear(5, 2, bias=False)}
    ),
    task_flow=[
        {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
        {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
    ],
    loss_func=partial(ce_loss, "m2"),
    output_func=partial(output, "m2"),
    scorer=Scorer(metrics=["accuracy"]),
)
# Test w/ dataparallel
# Build the multi-task model seeded with the initial task_1 definition.
model = EmmentalModel(name="test", tasks=task1)
# Create task
def ce_loss(task_name, immediate_ouput_dict, Y, active):
    """Cross-entropy loss on the task's prediction-head logits.

    Args:
        task_name: Task whose ``{task_name}_pred_head`` output is used.
        immediate_ouput_dict: Mapping of module name -> list of outputs.
        Y: Label tensor; flattened before the active mask is applied.
        active: Boolean mask selecting the rows that contribute to the loss.
    """
    head = f"{task_name}_pred_head"
    logits = immediate_ouput_dict[head][0]
    labels = Y.view(-1)
    return F.cross_entropy(logits[active], labels[active])
def output(task_name, immediate_ouput_dict):
    """Softmax probabilities from the task's ``{task_name}_pred_head`` logits."""
    head = f"{task_name}_pred_head"
    return F.softmax(immediate_ouput_dict[head][0], dim=1)
# Per-task metric sets; task2 is additionally evaluated with ROC-AUC.
task_metrics = {"task1": ["accuracy"], "task2": ["accuracy", "roc_auc"]}
tasks = [
EmmentalTask(
name=task_name,
module_pool=nn.ModuleDict(
{
"input_module": nn.Linear(2, 8),
f"{task_name}_pred_head": nn.Linear(8, 2),
}
),
task_flow=[
{
"name": "input",
"module": "input_module",
"inputs": [("_input_", "data")],
},
{
"name": f"{task_name}_pred_head",
"module": f"{task_name}_pred_head",
# Task "task_1": data -> m1 (10x10 linear) -> m2 (10x2 linear); loss/output
# functions read m2's logits, scored on accuracy. NOTE(review): this repeats
# an earlier identical construction — likely a separate test case in the
# original file; confirm before deduplicating.
task1 = EmmentalTask(
    name="task_1",
    module_pool=nn.ModuleDict(
        {"m1": nn.Linear(10, 10, bias=False), "m2": nn.Linear(10, 2, bias=False)}
    ),
    task_flow=[
        {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
        {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
    ],
    loss_func=partial(ce_loss, "m2"),
    output_func=partial(output, "m2"),
    scorer=Scorer(metrics=["accuracy"]),
)
# Replacement task reusing the name "task_1" with a narrower hidden layer
# (10->5->2) — presumably to exercise updating an existing task in the model;
# TODO confirm against the surrounding test.
new_task1 = EmmentalTask(
    name="task_1",
    module_pool=nn.ModuleDict(
        {"m1": nn.Linear(10, 5, bias=False), "m2": nn.Linear(5, 2, bias=False)}
    ),
    task_flow=[
        {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
        {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
    ],
    loss_func=partial(ce_loss, "m2"),
    output_func=partial(output, "m2"),
    scorer=Scorer(metrics=["accuracy"]),
)
task2 = EmmentalTask(
name="task_2",
module_pool=nn.ModuleDict(
r"""Build the MTL network using all tasks.
Args:
tasks(EmmentalTask or List[EmmentalTask]): A task or a list of tasks.
"""
if not isinstance(tasks, Iterable):
tasks = [tasks]
for task in tasks:
if task.name in self.task_names:
raise ValueError(
f"Found duplicate task {task.name}, different task should use "
f"different task name."
)
if not isinstance(task, EmmentalTask):
raise ValueError(f"Unrecognized task type {task}.")
self.add_task(task)
# Loss function
if pred_task_name in slice_distribution:
loss = partial(
utils.ce_loss,
pred_head_module_name,
weight=move_to_device(
slice_distribution[pred_task_name],
Meta.config["model_config"]["device"],
),
)
else:
loss = partial(utils.ce_loss, pred_head_module_name)
tasks.append(
EmmentalTask(
name=pred_task_name,
module_pool=pred_module_pool,
task_flow=pred_task_flow,
loss_func=loss,
output_func=partial(utils.output, pred_head_module_name),
scorer=task.scorer,
)
)
# Create master task
# Create task name
master_task_name = task.name
# Create attention module
master_attention_module_name = f"{master_task_name}_attention"
def add_task(self, task: EmmentalTask) -> None:
"""Add a single task into MTL network.
Args:
task: A task to add.
"""
if not isinstance(task, EmmentalTask):
raise ValueError(f"Unrecognized task type {task}.")
if task.name in self.task_names:
raise ValueError(
f"Found duplicate task {task.name}, different task should use "
f"different task name."
)
# Combine module_pool from all tasks
for key in task.module_pool.keys():
if key in self.module_pool.keys():
if Meta.config["model_config"]["dataparallel"]:
task.module_pool[key] = nn.DataParallel(self.module_pool[key])
else:
task.module_pool[key] = self.module_pool[key]
else:
"inputs": [
("_input_", "feature_index"),
("_input_", "feature_weight"),
],
},
{
"name": f"{task_name}_pred_head",
"module": f"{task_name}_pred_head",
"inputs": None,
},
]
else:
raise ValueError(f"Unrecognized model {model}.")
tasks.append(
EmmentalTask(
name=task_name,
module_pool=module_pool,
task_flow=task_flow,
loss_func=partial(loss, f"{task_name}_pred_head"),
output_func=partial(output, f"{task_name}_pred_head"),
scorer=Scorer(metrics=["accuracy", "precision", "recall", "f1"]),
)
)
return tasks