How to use the emmental.Meta.config attribute in emmental

To help you get started, we’ve selected a few emmental examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github SenWu / emmental / tests / logging / test_log_writer.py View on Github external
def test_tensorboard_writer(caplog):
    """Check TensorBoardWriter config and scalar handling end to end.

    Registers the current ``emmental.Meta.config`` and two scalars with a
    ``TensorBoardWriter``, writes the config to ``config.yaml`` under
    ``emmental.Meta.log_path``, reloads that file and verifies a few of its
    values, then calls ``write_log``.
    """
    caplog.set_level(logging.INFO)

    # Start from a reset, freshly initialised Meta state.
    emmental.Meta.reset()
    emmental.init()

    writer = TensorBoardWriter()
    writer.add_config(emmental.Meta.config)

    # Two scalar entries: "step 1" -> 0.1 at step 1, "step 2" -> 0.2 at step 2.
    for step, value in ((1, 0.1), (2, 0.2)):
        writer.add_scalar(name=f"step {step}", value=value, step=step)

    yaml_name = "config.yaml"
    writer.write_config(yaml_name)

    # Read the written config back from the log folder and inspect it.
    dumped_path = os.path.join(emmental.Meta.log_path, yaml_name)
    with open(dumped_path, "r") as handle:
        dumped = yaml.load(handle, Loader=yaml.FullLoader)

    assert dumped["meta_config"]["verbose"] is True
    logging_section = dumped["logging_config"]
    assert logging_section["counter_unit"] == "epoch"
    assert logging_section["checkpointing"] is False

    writer.write_log()
github SenWu / emmental / tests / test_meta.py View on Github external
emmental.Meta.update_config(
        path="tests/shared", filename="emmental-test-config.yaml"
    )
    assert Meta.config["meta_config"] == {
        "seed": 1,
        "verbose": False,
        "log_path": "tests",
        "use_exact_log_path": False,
    }

    # Test unable to find config file
    Meta.reset()
    emmental.init(dirpath)

    emmental.Meta.update_config(path=os.path.dirname(__file__))
    assert Meta.config["meta_config"] == {
        "seed": None,
        "verbose": True,
        "log_path": "logs",
        "use_exact_log_path": False,
    }

    # Remove the temp folder
    shutil.rmtree(dirpath)
github SenWu / emmental / tests / test_meta.py View on Github external
def test_meta(caplog):
    """Exercise Meta initialisation and config updating.

    Initialises emmental with a temporary log folder, checks that the folder
    exists and that ``Meta.log_path`` starts with it, verifies the initial
    ``meta_config`` values, then calls ``update_config`` with
    ``emmental-test-config.yaml`` from ``tests/shared`` and verifies the
    ``meta_config`` values expected afterwards.
    """
    caplog.set_level(logging.INFO)

    # Expected meta_config right after init, and after loading the test config.
    initial_meta_config = {
        "seed": None,
        "verbose": True,
        "log_path": "logs",
        "use_exact_log_path": False,
    }
    updated_meta_config = {
        "seed": 1,
        "verbose": False,
        "log_path": "tests",
        "use_exact_log_path": False,
    }

    dirpath = "temp_test_meta_log_folder"

    Meta.reset()
    emmental.init(dirpath)

    # The log folder must exist and Meta.log_path must live under it.
    assert os.path.isdir(dirpath) is True
    assert Meta.log_path.startswith(dirpath) is True

    # The config must be a dict carrying the initial meta settings.
    assert isinstance(Meta.config, dict) is True
    assert Meta.config["meta_config"] == initial_meta_config

    # Apply the shared test config file on top of the current config.
    emmental.Meta.update_config(
        path="tests/shared", filename="emmental-test-config.yaml"
    )
    assert Meta.config["meta_config"] == updated_meta_config
github SenWu / emmental / tests / utils / test_parse_args.py View on Github external
"checkpoint_all": True,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": False,
            },
        },
    }

    # Test default and default args are the same
    dirpath = "temp_parse_args"
    Meta.reset()
    emmental.init(dirpath)

    parser = parse_args()
    args = parser.parse_args([])
    config1 = parse_args_to_config(args)
    config2 = emmental.Meta.config

    del config2["learner_config"]["global_evaluation_metric_dict"]
    del config2["learner_config"]["optimizer_config"]["parameters"]
    assert config1 == config2

    shutil.rmtree(dirpath)
github SenWu / emmental / tests / utils / test_parse_args.py View on Github external
# Test different checkpoint_metric
    dirpath = "temp_parse_args"
    Meta.reset()
    emmental.init(
        log_dir=dirpath,
        config={
            "logging_config": {
                "checkpointer_config": {
                    "checkpoint_metric": {"model/valid/all/accuracy": "max"}
                }
            }
        },
    )

    assert emmental.Meta.config == {
        "meta_config": {
            "seed": None,
            "verbose": True,
            "log_path": "logs",
            "use_exact_log_path": False,
        },
        "data_config": {"min_data_len": 0, "max_data_len": 0},
        "model_config": {"model_path": None, "device": 0, "dataparallel": True},
        "learner_config": {
            "fp16": False,
            "n_epochs": 1,
            "train_split": ["train"],
            "valid_split": ["valid"],
            "test_split": ["test"],
            "ignore_index": None,
            "global_evaluation_metric_dict": None,
github SenWu / emmental / src / emmental / learner.py View on Github external
def _set_warmup_scheduler(self, model: EmmentalModel) -> None:
        """Set warmup learning rate scheduler for learning process.

        Reads ``warmup_steps`` and ``warmup_unit`` from
        ``Meta.config["learner_config"]["lr_scheduler_config"]`` and stores the
        resulting number of warmup steps in ``self.warmup_steps``. When
        ``warmup_steps`` is falsy (e.g. ``None`` or ``0``), ``self.warmup_steps``
        stays at ``0`` and ``warmup_unit`` is not consulted.

        Args:
          model: The model to set up warmup scheduler.

        Raises:
          ValueError: If ``warmup_steps`` is negative, or if ``warmup_unit`` is
            neither ``"epoch"`` nor ``"batch"``.
        """
        self.warmup_steps = 0
        # Look the scheduler section up once instead of re-indexing Meta.config.
        lr_scheduler_config = Meta.config["learner_config"]["lr_scheduler_config"]
        warmup_steps = lr_scheduler_config["warmup_steps"]
        if warmup_steps:
            # Negative values are truthy, so they reach this check.
            if warmup_steps < 0:
                raise ValueError("warmup_steps must be greater than or equal to 0.")
            warmup_unit = lr_scheduler_config["warmup_unit"]
            if warmup_unit == "epoch":
                # Epoch-based warmup is converted to a number of batches.
                self.warmup_steps = int(warmup_steps * self.n_batches_per_epoch)
            elif warmup_unit == "batch":
                self.warmup_steps = int(warmup_steps)
            else:
                raise ValueError(
                    f"warmup_unit must be 'batch' or 'epoch', but {warmup_unit} found."
                )
github SenWu / emmental / src / emmental / logging / logging_manager.py View on Github external
if self.counter_unit not in ["sample", "batch", "epoch"]:
            raise ValueError(f"Unrecognized unit: {self.counter_unit}")

        # Set up evaluation frequency
        self.evaluation_freq = Meta.config["logging_config"]["evaluation_freq"]
        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Evaluating every {self.evaluation_freq} {self.counter_unit}.")

        if Meta.config["logging_config"]["checkpointing"]:
            self.checkpointing = True

            # Set up checkpointing frequency
            self.checkpointing_freq = int(
                Meta.config["logging_config"]["checkpointer_config"]["checkpoint_freq"]
            )
            if Meta.config["meta_config"]["verbose"]:
                logger.info(
                    f"Checkpointing every "
                    f"{self.checkpointing_freq * self.evaluation_freq} "
                    f"{self.counter_unit}."
                )

            # Set up checkpointer
            self.checkpointer = Checkpointer()
        else:
            self.checkpointing = False
            if Meta.config["meta_config"]["verbose"]:
                logger.info("No checkpointing.")

        # Set up number of samples passed since last evaluation/checkpointing and
        # total number of samples passed since learning process
        self.sample_count: int = 0