# Test dependencies (emmental's standard module layout).
import logging
import os
import shutil

import torch.nn as nn

import emmental
from emmental import Meta
from emmental.learner import EmmentalLearner
from emmental.logging.log_writer import LogWriter
from emmental.model import EmmentalModel


def test_meta(caplog):
"""Unit test of meta."""
caplog.set_level(logging.INFO)
dirpath = "temp_test_meta_log_folder"
Meta.reset()
emmental.init(dirpath)
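    # emmental.init creates a run-specific log folder under dirpath (hence the
    # startswith check below) and loads the default config into Meta.config.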
# Check the log folder is created correctly
assert os.path.isdir(dirpath) is True
assert Meta.log_path.startswith(dirpath) is True
# Check the config is created
assert isinstance(Meta.config, dict) is True
assert Meta.config["meta_config"] == {
"seed": None,
"verbose": True,
"log_path": "logs",
"use_exact_log_path": False,
}
    # Test update_config with a yaml config file
    emmental.Meta.update_config(
        path="tests/shared", filename="emmental-test-config.yaml"
    )
    assert Meta.config["meta_config"] == {
        "seed": 1,
        "verbose": False,
        "log_path": "tests",
        "use_exact_log_path": False,
    }

    # Test unable to find config file
    Meta.reset()
    emmental.init(dirpath)
    emmental.Meta.update_config(path=os.path.dirname(__file__))
    assert Meta.config["meta_config"] == {
        "seed": None,
        "verbose": True,
        "log_path": "logs",
        "use_exact_log_path": False,
    }

    shutil.rmtree(dirpath)
lr_scheduler = "cosine_annealing"
dirpath = "temp_test_scheduler"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()
Meta.reset()
emmental.init(dirpath)
config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {"lr_scheduler": lr_scheduler},
}
}
emmental.Meta.update_config(config)
emmental_learner.n_batches_per_epoch = 1
emmental_learner._set_optimizer(model)
emmental_learner._set_lr_scheduler(model)
assert emmental_learner.optimizer.param_groups[0]["lr"] == 10
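    # With n_epochs=4 and n_batches_per_epoch=1 the schedule spans T=4 steps;
    # cosine annealing scales the lr by (1 + cos(pi * t / T)) / 2, so from lr=10
    # we expect ~8.5355 after step 1 and 5.0 after step 2.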
emmental_learner.optimizer.step()
emmental_learner._update_lr_scheduler(model, 0, {})
assert (
abs(emmental_learner.optimizer.param_groups[0]["lr"] - 8.535533905932738) < 1e-5
)
emmental_learner.optimizer.step()
emmental_learner._update_lr_scheduler(model, 1, {})
    assert abs(emmental_learner.optimizer.param_groups[0]["lr"] - 5) < 1e-5

    shutil.rmtree(dirpath)


def test_bert_adam_optimizer(caplog):
"""Unit test of BertAdam optimizer."""
caplog.set_level(logging.INFO)
optimizer = "bert_adam"
dirpath = "temp_test_optimizer"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()
Meta.reset()
emmental.init(dirpath)
# Test default BertAdam setting
config = {"learner_config": {"optimizer_config": {"optimizer": optimizer}}}
emmental.Meta.update_config(config)
emmental_learner._set_optimizer(model)
assert emmental_learner.optimizer.defaults == {
"lr": 0.001,
"betas": (0.9, 0.999),
"eps": 1e-08,
"weight_decay": 0.0,
}
# Test new BertAdam setting
config = {
"learner_config": {
"optimizer_config": {
"optimizer": optimizer,
"lr": 0.02,
"l2": 0.05,
def test_linear_scheduler(caplog):
"""Unit test of linear scheduler."""
caplog.set_level(logging.INFO)
lr_scheduler = "linear"
dirpath = "temp_test_scheduler"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()
Meta.reset()
emmental.init(dirpath)
# Test per batch
config = {
"learner_config": {
"n_epochs": 4,
"optimizer_config": {"optimizer": "sgd", "lr": 10},
"lr_scheduler_config": {"lr_scheduler": lr_scheduler},
}
}
emmental.Meta.update_config(config)
emmental_learner.n_batches_per_epoch = 1
emmental_learner._set_optimizer(model)
emmental_learner._set_lr_scheduler(model)
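    # The linear scheduler decays the lr linearly over the
    # n_epochs * n_batches_per_epoch scheduled steps, starting from the optimizer lr.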
    assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

    shutil.rmtree(dirpath)


def test_log_writer(caplog):
"""Unit test of log_writer."""
caplog.set_level(logging.INFO)
emmental.Meta.reset()
emmental.init()
emmental.Meta.update_config(
config={
"logging_config": {
"counter_unit": "sample",
"evaluation_freq": 10,
"checkpointing": True,
"checkpointer_config": {"checkpoint_freq": 2},
}
}
)
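    # LogWriter buffers the run config (and any metrics logged during training)
    # so they can be written out under Meta.log_path at the end of the run.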
log_writer = LogWriter()
log_writer.add_config(emmental.Meta.config)
def test_adam_optimizer(caplog):
"""Unit test of Adam optimizer."""
caplog.set_level(logging.INFO)
optimizer = "adam"
dirpath = "temp_test_optimizer"
model = nn.Linear(1, 1)
emmental_learner = EmmentalLearner()
Meta.reset()
emmental.init(dirpath)
# Test default Adam setting
config = {"learner_config": {"optimizer_config": {"optimizer": optimizer}}}
emmental.Meta.update_config(config)
emmental_learner._set_optimizer(model)
assert emmental_learner.optimizer.defaults == {
"lr": 0.001,
"betas": (0.9, 0.999),
"eps": 1e-08,
"amsgrad": False,
"weight_decay": 0,
}
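    # With only the optimizer name given, the defaults above are just the
    # fallback lr (0.001) plus torch.optim.Adam's stock hyperparameters.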
# Test new Adam setting
config = {
"learner_config": {
"optimizer_config": {
"optimizer": optimizer,
"lr": 0.02,


def test_model(caplog):
    """Unit test of EmmentalModel (excerpt)."""
    caplog.set_level(logging.INFO)

    # This excerpt assumes the setup from the full test: a temp folder `dirpath`,
    # EmmentalTask objects `task1` ("task_1") and `task2` ("task_2") already added
    # to `model`, and `new_task1` (a "task_1" variant with different module sizes).
    assert model.task_names == set(["task_1", "task_2"])
model.remove_task("task_1")
assert model.task_names == set(["task_2"])
model.save(f"{dirpath}/saved_model.pth")
model.load(f"{dirpath}/saved_model.pth")
# Test w/o dataparallel
Meta.reset()
emmental.init(dirpath)
config = {"model_config": {"dataparallel": False}}
emmental.Meta.update_config(config)
model = EmmentalModel(name="test", tasks=task1)
assert repr(model) == "EmmentalModel(name=test)"
assert model.name == "test"
assert model.task_names == set(["task_1"])
assert model.module_pool["m1"].weight.data.size() == (10, 10)
assert model.module_pool["m2"].weight.data.size() == (2, 10)
model.update_task(new_task1)
assert model.module_pool["m1"].weight.data.size() == (5, 10)
assert model.module_pool["m2"].weight.data.size() == (2, 5)
    model.update_task(task2)

    shutil.rmtree(dirpath)
"optimizer_config": {
"optimizer": optimizer,
"lr": 0.02,
"l2": 0.05,
f"{optimizer}_config": {
"max_iter": 30,
"max_eval": 40,
"tolerance_grad": 1e-04,
"tolerance_change": 1e-05,
"history_size": 10,
"line_search_fn": "strong_wolfe",
},
}
}
}
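    # Keys under f"{optimizer}_config" are passed through to the underlying torch
    # optimizer constructor, which is why the same values show up in
    # optimizer.defaults below.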
emmental.Meta.update_config(config)
emmental_learner._set_optimizer(model)
assert emmental_learner.optimizer.defaults == {
"lr": 0.02,
"max_iter": 30,
"max_eval": 40,
"tolerance_grad": 1e-04,
"tolerance_change": 1e-05,
"history_size": 10,
"line_search_fn": "strong_wolfe",
}
shutil.rmtree(dirpath)