How to use the smdebug.tensorflow.SessionHook class in smdebug

To help you get started, we’ve selected a few smdebug examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: awslabs / sagemaker-debugger / tests / tensorflow / hooks / test_weights_gradients.py (view on GitHub)
def test_only_w_g(out_dir):
    """Run the weights/gradients-only scenario against a SessionHook.

    The hook is built with ``save_all`` disabled and a save interval of 2
    steps. Running the model and checking the output are delegated to
    ``helper_test_only_w_g``.
    """
    pre_test_clean_up()
    two_step_interval = smd.SaveConfig(save_interval=2)
    session_hook = smd.SessionHook(out_dir, save_all=False, save_config=two_step_interval)
    helper_test_only_w_g(out_dir, session_hook)
GitHub: awslabs / sagemaker-debugger / tests / core / test_hook_save_scalar.py (view on GitHub)
hook = TF_KerasHook(
            out_dir=trial_dir,
            include_collections=[coll_name],
            save_config=save_config,
            export_tensorboard=True,
        )

        simple_tf_model(hook)

        saved_scalars = [
            "scalar/tf_keras_num_steps",
            "scalar/tf_keras_before_train",
            "scalar/tf_keras_after_train",
        ]
    else:
        hook = TF_SessionHook(
            out_dir=trial_dir,
            include_collections=[coll_name],
            save_config=save_config,
            export_tensorboard=True,
        )

        tf_session_model(hook)
        tf.reset_default_graph()

        saved_scalars = [
            "scalar/tf_session_num_steps",
            "scalar/tf_session_before_train",
            "scalar/tf_session_after_train",
        ]
    hook.close()
    verify_files(trial_dir, save_config, saved_scalars)
GitHub: awslabs / sagemaker-debugger / tests / tensorflow / hooks / test_mirrored_strategy.py (view on GitHub)
model_dir="/tmp/mnist_convnet_model",
    )

    if save_config is None:
        save_config = smd.SaveConfig(save_interval=2)

    if include_collections is None:
        include_collections = [
            CollectionKeys.WEIGHTS,
            CollectionKeys.BIASES,
            CollectionKeys.GRADIENTS,
            CollectionKeys.LOSSES,
        ]

    if not zcc:
        ts_hook = smd.SessionHook(
            out_dir=trial_dir,
            save_all=save_all,
            include_collections=include_collections,
            save_config=save_config,
            reduction_config=reduction_config,
            include_workers=include_workers,
        )
    else:
        print("zcc is passed. ignoring include_collections and save_config")

    mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, config=config)
    if steps is None:
        steps = ["train"]

    for s in steps:
        if s == "train":
GitHub: awslabs / sagemaker-debugger / tests / tensorflow / hooks / test_write.py (view on GitHub)
def test_hook_write(out_dir):
    """Write tensors through a save-everything SessionHook and read them back.

    The hook is configured with ``save_all=True`` and a save interval of 999.
    ``helper_hook_write`` drives it and writes into ``out_dir``. The resulting
    trial must then contain at least one tensor in the ``"weights"`` collection.
    """
    pre_test_clean_up()
    # set up hook
    hook = SessionHook(
        out_dir, save_all=True, include_collections=None, save_config=SaveConfig(save_interval=999)
    )
    helper_hook_write(out_dir, hook)
    tr = create_trial_fast_refresh(out_dir)
    # Query the collection once so the logged list and the checked list are
    # the same object and the trial is not scanned twice.
    weight_tensors = tr.tensors(collection="weights")
    print(weight_tensors)
    assert weight_tensors, f"no tensors found in the 'weights' collection under {out_dir}"
GitHub: awslabs / sagemaker-debugger / tests / tensorflow / hooks / test_reductions.py (view on GitHub)
def test_reductions(out_dir, save_raw_tensor=False):
    """Exercise every allowed reduction and norm through a SessionHook.

    The reduction config requests all of ``ALLOWED_REDUCTIONS`` and
    ``ALLOWED_NORMS``, in both plain and absolute-value forms. The hook saves
    the weights, gradients and losses collections every step. Verification is
    left to ``helper_test_reductions``.

    Args:
        out_dir: directory the hook writes its output into.
        save_raw_tensor: forwarded to ``ReductionConfig`` and to the helper.
            It controls whether the raw tensors are saved alongside their
            reductions.
    """
    pre_test_clean_up()
    reduction_cfg = smd.ReductionConfig(
        reductions=ALLOWED_REDUCTIONS,
        norms=ALLOWED_NORMS,
        abs_reductions=ALLOWED_REDUCTIONS,
        abs_norms=ALLOWED_NORMS,
        save_raw_tensor=save_raw_tensor,
    )
    collections_to_save = ["weights", "gradients", "losses"]
    every_step = smd.SaveConfig(save_interval=1)
    session_hook = smd.SessionHook(
        out_dir=out_dir,
        include_collections=collections_to_save,
        reduction_config=reduction_cfg,
        save_config=every_step,
    )
    helper_test_reductions(out_dir, session_hook, save_raw_tensor)
GitHub: awslabs / sagemaker-debugger / tests / tensorflow / hooks / test_save_config.py (view on GitHub)
def test_save_config(out_dir):
    """Drive ``helper_test_save_config`` with a SessionHook that saves every 2 steps.

    ``save_all`` is disabled on the hook. The training run and the checks on
    what was written to ``out_dir`` are performed by the helper.
    """
    pre_test_clean_up()
    interval_config = SaveConfig(save_interval=2)
    session_hook = SessionHook(out_dir=out_dir, save_all=False, save_config=interval_config)
    helper_test_save_config(out_dir, session_hook)
GitHub: awslabs / sagemaker-debugger / examples / tensorflow / scripts / distributed_training / parameter_server_training / parameter_server_mnist.py (view on GitHub)
)

        print("### Doing Multi GPU Training")
    else:
        strategy = None
    # Pass to RunConfig
    config = tf.estimator.RunConfig(train_distribute=strategy)

    # save tensors as reductions if necessary
    rdnc = (
        smd.ReductionConfig(reductions=["mean"], abs_reductions=["max"], norms=["l1"])
        if FLAGS.reductions
        else None
    )

    ts_hook = smd.SessionHook(
        out_dir=FLAGS.smdebug_path,
        save_all=FLAGS.save_all,
        include_collections=["weights", "gradients", "losses", "biases"],
        save_config=smd.SaveConfig(save_interval=FLAGS.save_frequency),
        reduction_config=rdnc,
    )

    ts_hook.set_mode(smd.modes.TRAIN)

    # Create the Estimator
    # pass RunConfig
    mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, config=config)

    hooks = list()
    hooks.append(ts_hook)
GitHub: awslabs / sagemaker-debugger / examples / tensorflow / scripts / train_imagenet_resnet_hvd.py (view on GitHub)
abs_reductions = []
    reductions = []
    if FLAGS.tornasole_relu_reductions:
        for r in FLAGS.tornasole_relu_reductions:
            reductions.append(r)
    if FLAGS.tornasole_relu_reductions_abs:
        for r in FLAGS.tornasole_relu_reductions_abs:
            abs_reductions.append(r)
    if reductions or abs_reductions:
        rnc = smd.ReductionConfig(reductions=reductions, abs_reductions=abs_reductions)
    else:
        rnc = None

    include_collections = ["losses"]

    hook = smd.SessionHook(
        out_dir=FLAGS.smdebug_path,
        save_config=smd.SaveConfig(save_interval=FLAGS.step_interval),
        reduction_config=rnc,
        include_collections=include_collections,
        save_all=FLAGS.tornasole_save_all,
    )
    if FLAGS.save_weights is True:
        include_collections.append("weights")
    if FLAGS.save_gradients is True:
        include_collections.append("gradients")
    if FLAGS.tornasole_save_relu_activations is True:
        include_collections.append("relu_activations")
    if FLAGS.tornasole_save_inputs is True:
        include_collections.append("inputs")
    if FLAGS.tornasole_include:
        hook.get_collection("default").include(FLAGS.tornasole_include)