# distribution = None

# Pass to RunConfig
config = tf.estimator.RunConfig(
    train_distribute=distribution,
    eval_distribute=distribution if eval_distributed else None,
    model_dir="/tmp/mnist_convnet_model",
)

if save_config is None:
    save_config = smd.SaveConfig(save_interval=2)

if include_collections is None:
    include_collections = [
        CollectionKeys.WEIGHTS,
        CollectionKeys.BIASES,
        CollectionKeys.GRADIENTS,
        CollectionKeys.LOSSES,
    ]

if not zcc:
    ts_hook = smd.SessionHook(
        out_dir=trial_dir,
        save_all=save_all,
        include_collections=include_collections,
        save_config=save_config,
        reduction_config=reduction_config,
        include_workers=include_workers,
    )
else:
    print("zcc is passed. ignoring include_collections and save_config")

mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, config=config)
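# For context, a hedged sketch of how a hook like ts_hook is then attached to
# the estimator (smd here is smdebug.tensorflow, as in the snippet above;
# train_input_fn/eval_input_fn and the step count are illustrative
# placeholders, not part of the original snippet):
ts_hook.set_mode(smd.modes.TRAIN)
mnist_classifier.train(input_fn=train_input_fn, steps=1000, hooks=[ts_hook])

ts_hook.set_mode(smd.modes.EVAL)
mnist_classifier.evaluate(input_fn=eval_input_fn, hooks=[ts_hook])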
TENSORFLOW_SUMMARIES = "tensorflow_summaries"
# XGBOOST
HYPERPARAMETERS = "hyperparameters"
METRICS = "metrics"
PREDICTIONS = "predictions"
LABELS = "labels"
FEATURE_IMPORTANCE = "feature_importance"
AVERAGE_SHAP = "average_shap"
FULL_SHAP = "full_shap"
TREES = "trees"
# Collection with summary objects instead of tensors
# so we don't create summaries or reductions of these
SUMMARIES_COLLECTIONS = {CollectionKeys.TENSORFLOW_SUMMARIES}
SCALAR_COLLECTIONS = {
    CollectionKeys.LOSSES,
    CollectionKeys.METRICS,
    CollectionKeys.FEATURE_IMPORTANCE,
    CollectionKeys.AVERAGE_SHAP,
    CollectionKeys.SM_METRICS,
}
SM_METRIC_COLLECTIONS = {CollectionKeys.LOSSES, CollectionKeys.METRICS, CollectionKeys.SM_METRICS}
# used by PyTorch, MXNet, and Keras
NON_REDUCTION_COLLECTIONS = SCALAR_COLLECTIONS.union(SUMMARIES_COLLECTIONS)
NON_HISTOGRAM_COLLECTIONS = SCALAR_COLLECTIONS.union(SUMMARIES_COLLECTIONS)
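# Illustration of how the hooks consume these groupings: scalar and summary
# collections are excluded from histogram/reduction writing. The helper below
# is hypothetical, written only to make the gating concrete.
def should_write_histogram(coll_name):
    return coll_name not in NON_HISTOGRAM_COLLECTIONS

assert not should_write_histogram(CollectionKeys.LOSSES)
assert should_write_histogram(CollectionKeys.WEIGHTS)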
self._increment_step(env.iteration)

if self.last_saved_step is not None and not self.exported_collections:
    self.export_collections()
    self.exported_collections = True

if not self._get_collections_to_save_for_step():
    self.logger.debug("Skipping iteration {}".format(self.step))
    return

self._initialize_writers()

if self._is_collection_being_saved_for_step(CollectionKeys.HYPERPARAMETERS):
    self.write_hyperparameters(env)
if self._is_collection_being_saved_for_step(CollectionKeys.METRICS):
    self.write_metrics(env)
if self._is_collection_being_saved_for_step(CollectionKeys.PREDICTIONS):
    self.write_predictions(env)
if self._is_collection_being_saved_for_step(CollectionKeys.LABELS):
    self.write_labels(env)
if self._is_collection_being_saved_for_step(CollectionKeys.FEATURE_IMPORTANCE):
    self.write_feature_importances(env)
if self._is_collection_being_saved_for_step(CollectionKeys.TREES):
    self.write_tree_model(env)
if self._is_collection_being_saved_for_step(CollectionKeys.FULL_SHAP):
    self._maybe_compute_shap_values(env)
    self.write_full_shap(env)
if self._is_collection_being_saved_for_step(CollectionKeys.AVERAGE_SHAP):
    self._maybe_compute_shap_values(env)
    # assumed counterpart to write_full_shap above
    self.write_average_shap(env)
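# Usage sketch: the XGBoost hook is registered as a training callback, so the
# method above runs once per boosting iteration (out_dir, params, and dtrain
# are illustrative placeholders; Hook/SaveConfig follow smdebug's XGBoost API):
import xgboost
import smdebug.xgboost as smd

hook = smd.Hook(out_dir="/tmp/xgb_debug", save_config=smd.SaveConfig(save_interval=5))
bst = xgboost.train(params, dtrain, num_boost_round=100, callbacks=[hook])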
"""
colls_with_tensor = set()
for coll in sorted(self._get_all_collections_to_save(), key=lambda x: x.name):
variable_collections_with_tensor, processed = self._process_tensor_from_variable_read_op(
tensor
)
if processed:
colls_with_tensor.update(variable_collections_with_tensor)
# processed=True means this tensor was either a variable read tensor,
# or a tensor with same name as variable
# former will be added to collections such as weights, biases, opt_variables
# latter will be skipped as they refer to the same tensor
else:
# some collections are added automatically, don't match regex for these
if coll.name not in [
CollectionKeys.WEIGHTS,
CollectionKeys.BIASES,
CollectionKeys.OPTIMIZER_VARIABLES,
CollectionKeys.TENSORFLOW_SUMMARIES,
] and match_inc(tensor.name, coll.include_regex):
coll.add(tensor)
if coll.has_tensor(tensor.name):
# it must have been added when collection was added to
# from user(custom_coll)/library(losses, weights, grads)
tensor_ref = coll.get_tensor(tensor.name)
tensor_ref.tf_obj = tensor
colls_with_tensor.add(coll)
# create entry in hook's tensor_to_collections map for this tensor
self._create_tensors_for_matching_collections(tensor, colls_with_tensor)
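# match_inc above decides whether a tensor name matches any of a collection's
# include regexes; a behavioral approximation (not the library's exact code):
import re

def match_inc_sketch(tname, include):
    patterns = [include] if isinstance(include, str) else include
    return any(re.search(p, tname) for p in patterns)

# e.g. a custom collection with include_regex ["^dense/.*weights"] would pick
# up "dense/1/weights:0" but not "conv1/bias:0"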
def _get_collections_to_save_for_step(self) -> Set["Collection"]:
    if self._collections_to_save_for_step is None:
        self._assert_prep()
        self._collections_to_save_for_step = set()
        for coll in self._get_all_collections_to_save():
            if self.mode in [ModeKeys.EVAL, ModeKeys.PREDICT]:
                if coll.name in [CollectionKeys.GRADIENTS, CollectionKeys.OPTIMIZER_VARIABLES]:
                    continue
            if coll.save_config.should_save_step(self.mode, self.mode_steps[self.mode]):
                self._collections_to_save_for_step.add(coll)

        if self._collections_to_save_for_step:
            if self.mode == ModeKeys.GLOBAL:
                step_str = f"for step {self.step}"
            else:
                step_str = f"for step {self.mode_steps[self.mode]} of mode {self.mode.name}"
            self.logger.debug(
                f"Saving the collections "
                f"{', '.join([x.name for x in self._collections_to_save_for_step])} {step_str}"
            )
    return self._collections_to_save_for_step
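# should_save_step is driven by each collection's SaveConfig; per-mode save
# intervals can be configured like this (a sketch after smdebug's documented
# SaveConfig/SaveConfigMode API; the interval values are illustrative):
import smdebug.tensorflow as smd

save_config = smd.SaveConfig(
    mode_save_configs={
        smd.modes.TRAIN: smd.SaveConfigMode(save_interval=100),
        smd.modes.EVAL: smd.SaveConfigMode(save_interval=10),
    }
)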
def _register_default_collections(self):
    self.get(CollectionKeys.HYPERPARAMETERS).include("^hyperparameters/.*$")
    # matches XGBoost evaluation metric names such as "train-rmse"
    self.get(CollectionKeys.METRICS).include("^[a-zA-Z]+-[a-zA-Z0-9]+$")
    self.get(CollectionKeys.PREDICTIONS).include("^predictions$")
    self.get(CollectionKeys.LABELS).include("^labels$")
    self.get(CollectionKeys.FEATURE_IMPORTANCE).include("^feature_importance/.*")
    # exclude the bias entry; a trailing character class like [^/bias] would
    # wrongly drop any name ending in b/i/a/s, so use a negative lookbehind
    self.get(CollectionKeys.AVERAGE_SHAP).include("^average_shap/.*(?<!/bias)$")
    self.get(CollectionKeys.FULL_SHAP).include("^full_shap/.*(?<!/bias)$")
    self.get(CollectionKeys.TREES).include("^trees/.*")
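# Quick check of the patterns above (illustrative; the metric names are the
# "<data>-<metric>" strings XGBoost reports via env.evaluation_result_list):
import re

assert re.match(r"^[a-zA-Z]+-[a-zA-Z0-9]+$", "train-rmse")
assert re.match(r"^[a-zA-Z]+-[a-zA-Z0-9]+$", "validation-auc")
assert not re.match(r"^[a-zA-Z]+-[a-zA-Z0-9]+$", "predictions")
assert re.match(r"^average_shap/.*(?<!/bias)$", "average_shap/f1")
assert not re.match(r"^average_shap/.*(?<!/bias)$", "average_shap/bias")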
def _prepare_collections(self):
    """Populate collections_to_save and ensure every collection has
    a save_config and reduction_config."""
    for c_name, c in self.collection_manager.get_collections().items():
        if c in self._collections_to_save:
            continue
        elif self._should_collection_be_saved(CollectionKeys.ALL):
            self._collections_to_save.add(c)
        elif self._should_collection_be_saved(c_name):
            self._collections_to_save.add(c)

    self.logger.info(
        f'Monitoring the collections: {", ".join([x.name for x in self._collections_to_save])}'
    )

    # Populate configs_for_collections and reduction_config
    for c_name, c in self.collection_manager.get_collections().items():
        if c_name in NON_HISTOGRAM_COLLECTIONS:
            c.save_histogram = False
        if c.save_config is None:
            # Set to the default if None
            c.save_config = self.save_config
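# Putting the selection rules together: save_all=True makes
# _should_collection_be_saved(CollectionKeys.ALL) true, so every registered
# collection is monitored; otherwise only those named in include_collections
# are, and any collection without its own save_config inherits the hook's
# default. A sketch under those assumptions (out_dir is a placeholder):
import smdebug.xgboost as smd

hook = smd.Hook(out_dir="/tmp/xgb_all", save_all=True)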