How to use the smdebug.core.modes.ModeKeys enum in smdebug

To help you get started, we’ve selected a few smdebug examples, based on popular ways it is used in public projects.
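
ModeKeys is a Python enum whose members name the phase of a job a step belongs to: TRAIN, EVAL, PREDICT, or the default GLOBAL. A minimal sketch of inspecting it, assuming the standard definition in smdebug.core.modes:

from smdebug.core.modes import ModeKeys

# The four phases smdebug distinguishes when counting steps and
# indexing saved tensors. GLOBAL is the fallback used when no mode
# has been set on the hook.
for mode in ModeKeys:
    print(mode.name, mode.value)  # e.g. TRAIN 1, EVAL 2, PREDICT 3, GLOBAL 4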

github awslabs / sagemaker-debugger / tests / core / test_hook_save_scalar.py
    opt = tf.train.RMSPropOptimizer(lr)
    opt = hook.wrap_optimizer(opt)

    model.compile(
        optimizer=opt,
        loss="sparse_categorical_crossentropy",
        run_eagerly=False,
        metrics=["accuracy"],
    )
    hooks = [hook]
    hook.save_scalar("tf_keras_num_steps", steps, sm_metric=True)

    hook.save_scalar("tf_keras_before_train", 1, sm_metric=False)
    hook.set_mode(ModeKeys.TRAIN)
    model.fit(x_train, y_train, epochs=1, steps_per_epoch=steps, callbacks=hooks, verbose=0)

    hook.set_mode(ModeKeys.EVAL)
    model.evaluate(x_test, y_test, steps=10, callbacks=hooks, verbose=0)
    hook.save_scalar("tf_keras_after_train", 1, sm_metric=False)
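
The hook object this test uses is constructed before the excerpt begins. A minimal sketch of that setup, assuming smdebug's TensorFlow KerasHook and a hypothetical output directory:

import smdebug.tensorflow as smd
from smdebug.core.modes import ModeKeys

# Hypothetical out_dir; the hook writes saved tensors and scalars there.
hook = smd.KerasHook(out_dir="/tmp/smdebug_demo")
hook.set_mode(ModeKeys.TRAIN)  # subsequent steps count against TRAIN

Passing the hook through callbacks, as above, lets it advance the mode-specific step counters as fit and evaluate run.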
github awslabs / sagemaker-debugger / smdebug / tensorflow / utils.py
def get_keras_mode(mode):
    # Should never be called in TF 1.13 where this is not available
    from tensorflow.python.keras.utils.mode_keys import ModeKeys as KerasModeKeys

    if mode == ModeKeys.TRAIN:
        return KerasModeKeys.TRAIN
    elif mode == ModeKeys.EVAL:
        return KerasModeKeys.TEST
    elif mode == ModeKeys.PREDICT:
        return KerasModeKeys.PREDICT
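
As a usage sketch (on TF versions newer than 1.13, per the comment above): the mapping only diverges for EVAL, which Keras calls TEST.

from smdebug.core.modes import ModeKeys
from smdebug.tensorflow.utils import get_keras_mode

keras_mode = get_keras_mode(ModeKeys.EVAL)  # Keras' TEST phase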
github awslabs / sagemaker-debugger / smdebug / core / hook.py
def _get_collections_to_save_for_step(self) -> Set["Collection"]:
        if self._collections_to_save_for_step is None:
            self._assert_prep()
            self._collections_to_save_for_step = set()
            for coll in self._get_all_collections_to_save():
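                # Gradients and optimizer variables only exist during
                # training, so they are skipped for EVAL and PREDICT steps.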
                if self.mode in [ModeKeys.EVAL, ModeKeys.PREDICT]:
                    if coll.name in [CollectionKeys.GRADIENTS, CollectionKeys.OPTIMIZER_VARIABLES]:
                        continue
                if coll.save_config.should_save_step(self.mode, self.mode_steps[self.mode]):
                    self._collections_to_save_for_step.add(coll)

            if self._collections_to_save_for_step:
                if self.mode == ModeKeys.GLOBAL:
                    step_str = f"for step {self.step}"
                else:
                    step_str = f"for step {self.mode_steps[self.mode]} of mode {self.mode.name}"
                self.logger.debug(
                    f"Saving the collections "
                    f"{', '.join([x.name for x in self._collections_to_save_for_step])} {step_str}"
                )
        return self._collections_to_save_for_step
github awslabs / sagemaker-debugger / smdebug / core / tfrecord / tensor_reader.py
def _get_mode_modestep(self, step, plugin_data):
        mode_step = step
        mode = ModeKeys.GLOBAL
        for metadata in plugin_data:
            if metadata.plugin_name == MODE_STEP_PLUGIN_NAME:
                mode_step = int(metadata.content)
            if metadata.plugin_name == MODE_PLUGIN_NAME:
                mode = ModeKeys(int(metadata.content))
        return mode, mode_step
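
The mode travels through the event file's plugin data as the enum member's integer value, so decoding it is a plain constructor call. A small sketch, assuming the standard member values in smdebug.core.modes:

from smdebug.core.modes import ModeKeys

content = "2"  # as read from metadata.content
mode = ModeKeys(int(content))  # recovers the member from its value, here EVAL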
github awslabs / sagemaker-debugger / smdebug / core / writer.py
def _check_mode_step(mode, mode_step, global_step):
        if mode_step is None:
            mode_step = global_step
        if mode is None:
            mode = ModeKeys.GLOBAL
        if not isinstance(mode, ModeKeys):
            mode_keys = ["ModeKeys." + x.name for x in ModeKeys]
            ex_str = "mode can be one of " + ", ".join(mode_keys)
            raise ValueError(ex_str)
        return mode, mode_step
github awslabs / sagemaker-debugger / smdebug / trials / trial.py
def tensor_names(self, *, step=None, mode=ModeKeys.GLOBAL, regex=None, collection=None) -> list:
        self.maybe_refresh()
        ts = set()
        if step is None and mode == ModeKeys.GLOBAL:
            ts.update(self._tensors.keys())
    elif step is None and mode != ModeKeys.GLOBAL:
            ts.update(self.mode_to_tensors_map[mode])
        else:
            ts.update(self._tensors_for_step(step, mode))
        self.logger.debug(
            f"getting tensor_names with params: step:{step} mode:{mode} regex:{regex} collection:{collection}"
        )

        if regex is None and collection is None:
            return sorted(list(ts))
        elif regex is not None and collection is not None:
            raise ValueError("Only one of `regex` or `collection` can be passed to this method")
        else:
            if collection is not None:
                xs = set(self._tensors.keys()).intersection(self._tensors_in_collection(collection))
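
From user code, tensor_names is reached through a trial object. A minimal hedged sketch, using a hypothetical output directory from an earlier run:

from smdebug.trials import create_trial
from smdebug.core.modes import ModeKeys

trial = create_trial("/tmp/smdebug_demo")  # hypothetical training out_dir
all_names = trial.tensor_names()  # every tensor across all modes
train_names = trial.tensor_names(mode=ModeKeys.TRAIN)  # TRAIN-mode tensors only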
github awslabs / sagemaker-debugger / smdebug / tensorflow / keras.py
def _get_exec_function(self, mode):
        if self.distribution_strategy in [
            TFDistributionStrategy.NONE,
            TFDistributionStrategy.HOROVOD,
        ]:
            if mode == ModeKeys.TRAIN:
                x = self.model.train_function
            elif mode == ModeKeys.EVAL:
                x = self.model.test_function
            elif mode == ModeKeys.PREDICT:
                x = self.model.predict_function
            else:
                raise NotImplementedError
        else:
            x = self._get_distributed_model(mode)._distributed_function
        return x