Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_only_w_g(out_dir):
    """Run ``helper_test_only_w_g`` with a freshly built ``smd.SessionHook``.

    The hook writes to ``out_dir`` with ``save_all=False`` and a
    ``SaveConfig`` whose ``save_interval`` is 2.
    """
    pre_test_clean_up()
    interval_config = smd.SaveConfig(save_interval=2)
    session_hook = smd.SessionHook(out_dir, save_all=False, save_config=interval_config)
    helper_test_only_w_g(out_dir, session_hook)
# NOTE(review): this excerpt begins inside a conditional whose `if` header and
# enclosing function are outside the visible source, and the original
# indentation has been lost. Comments below describe only what is visible.
# First branch: build a TF_KerasHook that writes to trial_dir with TensorBoard
# export enabled, then hand it to simple_tf_model.
hook = TF_KerasHook(
out_dir=trial_dir,
include_collections=[coll_name],
save_config=save_config,
export_tensorboard=True,
)
simple_tf_model(hook)
# Scalar names for the Keras branch; passed to verify_files at the end.
saved_scalars = [
"scalar/tf_keras_num_steps",
"scalar/tf_keras_before_train",
"scalar/tf_keras_after_train",
]
else:
# Other branch: identical hook arguments, but using TF_SessionHook and
# tf_session_model instead.
hook = TF_SessionHook(
out_dir=trial_dir,
include_collections=[coll_name],
save_config=save_config,
export_tensorboard=True,
)
tf_session_model(hook)
# Reset the default TF graph once the session model has run.
tf.reset_default_graph()
# Scalar names for the session branch; passed to verify_files at the end.
saved_scalars = [
"scalar/tf_session_num_steps",
"scalar/tf_session_before_train",
"scalar/tf_session_after_train",
]
# Both branches: close the hook, then pass the output directory, save config
# and expected scalar names to verify_files.
hook.close()
verify_files(trial_dir, save_config, saved_scalars)
# NOTE(review): this excerpt starts in the middle of a call (only its final
# keyword argument is visible) and the enclosing function header is not shown;
# the original indentation has been lost.
model_dir="/tmp/mnist_convnet_model",
)
# When no save_config is supplied, fall back to a SaveConfig with save_interval=2.
if save_config is None:
save_config = smd.SaveConfig(save_interval=2)
# When no include_collections is supplied, default to weights, biases,
# gradients and losses.
if include_collections is None:
include_collections = [
CollectionKeys.WEIGHTS,
CollectionKeys.BIASES,
CollectionKeys.GRADIENTS,
CollectionKeys.LOSSES,
]
# A SessionHook is only constructed here when zcc is falsy.
if not zcc:
ts_hook = smd.SessionHook(
out_dir=trial_dir,
save_all=save_all,
include_collections=include_collections,
save_config=save_config,
reduction_config=reduction_config,
include_workers=include_workers,
)
else:
# With zcc set, no hook is built in this branch; the message states that
# include_collections and save_config are ignored.
print("zcc is passed. ignoring include_collections and save_config")
# Build the Estimator from cnn_model_fn and the previously created config.
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, config=config)
# Default to a single "train" entry when steps is not given.
if steps is None:
steps = ["train"]
for s in steps:
if s == "train":
# NOTE(review): the body of this branch is cut off in this excerpt.
def test_hook_write(out_dir):
    """Run ``helper_hook_write`` with a save-all hook and check "weights" tensors exist.

    After the helper finishes, a trial is created over ``out_dir`` and the test
    asserts that its "weights" collection reports at least one tensor.
    """
    pre_test_clean_up()
    # set up hook
    save_settings = SaveConfig(save_interval=999)
    session_hook = SessionHook(
        out_dir,
        save_all=True,
        include_collections=None,
        save_config=save_settings,
    )
    helper_hook_write(out_dir, session_hook)

    trial = create_trial_fast_refresh(out_dir)
    print(trial.tensors(collection="weights"))
    assert len(trial.tensors(collection="weights")) > 0
def test_reductions(out_dir, save_raw_tensor=False):
    """Run ``helper_test_reductions`` with every allowed reduction and norm enabled.

    Args:
        out_dir: Directory the hook writes to; also forwarded to the helper.
        save_raw_tensor: Forwarded to ``smd.ReductionConfig`` and to the helper.
    """
    pre_test_clean_up()

    reduction_settings = dict(
        reductions=ALLOWED_REDUCTIONS,
        abs_reductions=ALLOWED_REDUCTIONS,
        norms=ALLOWED_NORMS,
        abs_norms=ALLOWED_NORMS,
        save_raw_tensor=save_raw_tensor,
    )
    reduction_config = smd.ReductionConfig(**reduction_settings)

    session_hook = smd.SessionHook(
        out_dir=out_dir,
        save_config=smd.SaveConfig(save_interval=1),
        reduction_config=reduction_config,
        include_collections=["weights", "gradients", "losses"],
    )
    helper_test_reductions(out_dir, session_hook, save_raw_tensor)
def test_save_config(out_dir):
    """Run ``helper_test_save_config`` with a hook using ``save_interval=2``.

    The hook writes to ``out_dir`` and is created with ``save_all=False``.
    """
    pre_test_clean_up()
    interval_config = SaveConfig(save_interval=2)
    session_hook = SessionHook(out_dir=out_dir, save_all=False, save_config=interval_config)
    helper_test_save_config(out_dir, session_hook)
# NOTE(review): this excerpt begins with the closing parenthesis of a call that
# is not visible, inside an `if` branch whose header is also not shown; the
# original indentation has been lost.
)
print("### Doing Multi GPU Training")
else:
# No distribution strategy in this branch.
strategy = None
# Pass to RunConfig
config = tf.estimator.RunConfig(train_distribute=strategy)
# save tensors as reductions if necessary
# (rdnc is None unless FLAGS.reductions is truthy)
rdnc = (
smd.ReductionConfig(reductions=["mean"], abs_reductions=["max"], norms=["l1"])
if FLAGS.reductions
else None
)
# Hook writing to FLAGS.smdebug_path; the save interval comes from
# FLAGS.save_frequency and the reduction config from rdnc above.
ts_hook = smd.SessionHook(
out_dir=FLAGS.smdebug_path,
save_all=FLAGS.save_all,
include_collections=["weights", "gradients", "losses", "biases"],
save_config=smd.SaveConfig(save_interval=FLAGS.save_frequency),
reduction_config=rdnc,
)
# Put the hook in TRAIN mode.
ts_hook.set_mode(smd.modes.TRAIN)
# Create the Estimator
# pass RunConfig
mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, config=config)
# Collect the hooks; ts_hook is the only one added in the visible code.
hooks = list()
hooks.append(ts_hook)
# Gather the optional reductions requested through flags. A falsy flag
# (unset / empty) yields an empty list.
abs_reductions = list(FLAGS.tornasole_relu_reductions_abs or [])
reductions = list(FLAGS.tornasole_relu_reductions or [])

# Only build a ReductionConfig when at least one reduction was requested;
# otherwise the hook receives reduction_config=None.
if reductions or abs_reductions:
    rnc = smd.ReductionConfig(reductions=reductions, abs_reductions=abs_reductions)
else:
    rnc = None

# Assemble the complete collection list BEFORE constructing the hook.
# Appending to the list after it has been handed to SessionHook only takes
# effect if the hook keeps a reference to the caller's list instead of copying
# it, so the flag-selected collections could be silently dropped.
include_collections = ["losses"]
if FLAGS.save_weights is True:
    include_collections.append("weights")
if FLAGS.save_gradients is True:
    include_collections.append("gradients")
if FLAGS.tornasole_save_relu_activations is True:
    include_collections.append("relu_activations")
if FLAGS.tornasole_save_inputs is True:
    include_collections.append("inputs")

hook = smd.SessionHook(
    out_dir=FLAGS.smdebug_path,
    save_config=smd.SaveConfig(save_interval=FLAGS.step_interval),
    reduction_config=rnc,
    include_collections=include_collections,
    save_all=FLAGS.tornasole_save_all,
)

# Extra include pattern(s) from the flag are added to the hook's "default"
# collection.
if FLAGS.tornasole_include:
    hook.get_collection("default").include(FLAGS.tornasole_include)