How to use the tensorflowjs.write_weights.write_weights function in tensorflowjs

To help you get started, we’ve selected a few tensorflowjs examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github tensorflow / tfjs-converter / python / tensorflowjs / converters / tf_saved_model_conversion_pb.py View on Github external
tf.import_graph_def(graph_def, name='')
    for const in constants:
      tensor = graph.get_tensor_by_name(const.name + ':0')
      value = tensor.eval(session=sess)
      if not isinstance(value, np.ndarray):
        value = np.array(value)

      # Record the evaluated value for the weights manifest, then restore the
      # conditional inputs
      const_manifest.append({'name': const.name, 'data': value})
      const.input[:] = constInputs[const.name]

      # Remove the binary array from tensor and save it to the external file.
      for field_name in CLEARED_TENSOR_FIELDS:
        const.attr["value"].tensor.ClearField(field_name)

  write_weights.write_weights(
      [const_manifest], path, quantization_dtype=quantization_dtype)

  file_io.atomic_write_string_to_file(
      os.path.abspath(output_graph), graph_def.SerializeToString())
github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / keras_h5_conversion.py View on Github external
raise ValueError(
        'Expected weight_shard_size_bytes to be a positive integer, '
        'but got %s' % weight_shard_size_bytes)

  # Refuse to continue when output_dir names an existing regular file, since
  # it is used below as the directory that receives the model JSON.
  # The path is formatted with %s: %d only accepts numbers and would raise a
  # TypeError for a string path, masking the intended ValueError.
  if os.path.isfile(output_dir):
    raise ValueError(
        'Path "%s" already exists as a file (not a directory).' % output_dir)

  model_json = {
      common.FORMAT_KEY: common.TFJS_LAYERS_MODEL_FORMAT,
      common.GENERATED_BY_KEY: _get_generated_by(topology),
      common.CONVERTED_BY_KEY: common.get_converted_by(),
  }

  model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
  weights_manifest = write_weights.write_weights(
      weights, output_dir, write_manifest=False,
      quantization_dtype=quantization_dtype,
      shard_size_bytes=weight_shard_size_bytes)
  assert isinstance(weights_manifest, list)
  model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

  model_json_path = os.path.join(
      output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)
  with open(model_json_path, 'wt') as f:
    json.dump(model_json, f)
github tensorflow / tfjs-converter / python / tensorflowjs / converters / tf_saved_model_conversion.py View on Github external
"""Writes weights and topology to the output_dir.

  If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.

  Args:
    topology: tf.GraphDef TensorFlow GraphDef proto object, which represents
      the model topology.
    weights: an array of weight groups (as defined in tfjs write_weights).
    output_graph: the output file name to hold all the contents.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
  """
  model_json = {}

  model_json[ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
  weights_manifest = write_weights.write_weights(
      weights, os.path.dirname(output_graph), write_manifest=False,
      quantization_dtype=quantization_dtype)
  assert isinstance(weights_manifest, list)
  model_json[ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

  with open(output_graph, 'wt') as f:
    json.dump(model_json, f)
github tensorflow / tfjs / tfjs-converter / python / tensorflowjs / converters / tf_saved_model_conversion_v2.py View on Github external
tf_version: Tensorflow version of the input graph.
    signature_def: the SignatureDef of the inference graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
  """
  model_json = {
      common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT,
      # TODO(piyu): Add tensorflow version below by using `meta_info_def`.
      common.GENERATED_BY_KEY: tf_version,
      common.CONVERTED_BY_KEY: common.get_converted_by(),
      common.USER_DEFINED_METADATA_KEY: {
          common.SIGNATURE_KEY: MessageToDict(signature_def)
      }
  }
  model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
  weights_manifest = write_weights.write_weights(
      weights, os.path.dirname(output_graph), write_manifest=False,
      quantization_dtype=quantization_dtype)
  assert isinstance(weights_manifest, list)
  model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

  with open(output_graph, 'wt') as f:
    json.dump(model_json, f)