tf.import_graph_def(graph_def, name='')
for const in constants:
  tensor = graph.get_tensor_by_name(const.name + ':0')
  value = tensor.eval(session=sess)
  if not isinstance(value, np.ndarray):
    value = np.array(value)

  const_manifest.append({'name': const.name, 'data': value})
  # Restore the conditional inputs
  const.input[:] = constInputs[const.name]

  # Remove the binary array from tensor and save it to the external file.
  for field_name in CLEARED_TENSOR_FIELDS:
    const.attr["value"].tensor.ClearField(field_name)

write_weights.write_weights(
    [const_manifest], path, quantization_dtype=quantization_dtype)

file_io.atomic_write_string_to_file(
    os.path.abspath(output_graph), graph_def.SerializeToString())
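
# --- Minimal usage sketch (not part of the converter source) ---
# write_weights expects a list of weight groups, each entry pairing a tensor name
# with a numpy array, exactly like the const_manifest built above. The output
# directory and tensor names below are hypothetical.
import os
import numpy as np
from tensorflowjs import write_weights

example_manifest = [
    {'name': 'example/const_a', 'data': np.array([1.0, 2.0, 3.0], dtype=np.float32)},
    {'name': 'example/const_b', 'data': np.zeros((2, 2), dtype=np.float32)},
]
example_dir = '/tmp/example_weights'
os.makedirs(example_dir, exist_ok=True)
write_weights.write_weights([example_manifest], example_dir)
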
raise ValueError(
    'Expected weight_shard_size_bytes to be a positive integer, '
    'but got %s' % weight_shard_size_bytes)

if os.path.isfile(output_dir):
  raise ValueError(
      'Path "%s" already exists as a file (not a directory).' % output_dir)
model_json = {
    common.FORMAT_KEY: common.TFJS_LAYERS_MODEL_FORMAT,
    common.GENERATED_BY_KEY: _get_generated_by(topology),
    common.CONVERTED_BY_KEY: common.get_converted_by(),
}
model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None

weights_manifest = write_weights.write_weights(
    weights, output_dir, write_manifest=False,
    quantization_dtype=quantization_dtype,
    shard_size_bytes=weight_shard_size_bytes)
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

model_json_path = os.path.join(
    output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)
with open(model_json_path, 'wt') as f:
  json.dump(model_json, f)
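
# --- Sketch: reading back the artifacts written above (hypothetical path) ---
# Assumes the enclosing function was called with output_dir='/tmp/tfjs_layers_model'
# and that the same `common` constants module used above is importable here.
import json
import os

model_json_path = os.path.join('/tmp/tfjs_layers_model',
                               common.ARTIFACT_MODEL_JSON_FILE_NAME)
with open(model_json_path, 'rt') as f:
  artifacts = json.load(f)
print(artifacts[common.FORMAT_KEY])                           # layers-model format tag
print(len(artifacts[common.ARTIFACT_WEIGHTS_MANIFEST_KEY]))   # number of weight groups
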
"""Writes weights and topology to the output_dir.
If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.
Args:
topology: tf.GraphDef TensorFlow GraphDef proto object, which represents
the model topology.
weights: an array of weight groups (as defined in tfjs write_weights).
output_graph: the output file name to hold all the contents.
quantization_dtype: An optional numpy dtype to quantize weights to for
compression. Only np.uint8 and np.uint16 are supported.
"""
model_json = {}
model_json[ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None

weights_manifest = write_weights.write_weights(
    weights, os.path.dirname(output_graph), write_manifest=False,
    quantization_dtype=quantization_dtype)
assert isinstance(weights_manifest, list)
model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

with open(output_graph, 'wt') as f:
  json.dump(model_json, f)
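
# --- Usage sketch for the function documented above ---
# The enclosing function's name is not shown in this fragment; `write_artifacts` is
# assumed here. Passing a Falsy topology exercises the documented weights-only path,
# so nothing non-JSON-serializable ends up in model_json. Paths and tensor names are
# hypothetical.
import os
import numpy as np

os.makedirs('/tmp/web_model', exist_ok=True)
weight_groups = [[
    {'name': 'dense/kernel', 'data': np.random.rand(4, 2).astype(np.float32)},
    {'name': 'dense/bias', 'data': np.zeros((2,), dtype=np.float32)},
]]
write_artifacts(None, weight_groups, '/tmp/web_model/model.json',
                quantization_dtype=np.uint8)
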
tf_version: TensorFlow version of the input graph.
signature_def: the SignatureDef of the inference graph.
quantization_dtype: An optional numpy dtype to quantize weights to for
  compression. Only np.uint8 and np.uint16 are supported.
"""
model_json = {
    common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT,
    # TODO(piyu): Add tensorflow version below by using `meta_info_def`.
    common.GENERATED_BY_KEY: tf_version,
    common.CONVERTED_BY_KEY: common.get_converted_by(),
    common.USER_DEFINED_METADATA_KEY: {
        common.SIGNATURE_KEY: MessageToDict(signature_def)
    }
}
model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None

weights_manifest = write_weights.write_weights(
    weights, os.path.dirname(output_graph), write_manifest=False,
    quantization_dtype=quantization_dtype)
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

with open(output_graph, 'wt') as f:
  json.dump(model_json, f)
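
# --- Hypothetical invocation sketch ---
# The enclosing function's name and full parameter order are not shown in this
# fragment; a function named `write_artifacts` taking keyword arguments matching the
# names used in the body and docstring above is assumed. Because the topology is
# written with json.dump, it is passed here as the dict form of a frozen GraphDef.
import numpy as np
import tensorflow as tf
from google.protobuf.json_format import MessageToDict

weight_groups = [[
    {'name': 'conv1/kernel', 'data': np.zeros((3, 3, 3, 8), dtype=np.float32)},
]]
write_artifacts(topology=MessageToDict(graph_def),  # graph_def: frozen tf.GraphDef (assumed)
                weights=weight_groups,
                output_graph='/tmp/web_model/model.json',
                tf_version=tf.__version__,
                signature_def=signature_def,  # SignatureDef proto (assumed)
                quantization_dtype=None)
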