How to use the tensorflowjs.converters.fuse_prelu.register_prelu_func function in tensorflowjs

To help you get started, we’ve selected a few tensorflowjs examples based on common ways the library is used in public projects.
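
The two snippets below come from the tfjs-converter source itself. In both, register_prelu_func is called on a tf.Graph before a GraphDef is imported or optimized; the call attaches a function definition for the custom Prelu op to the graph, since Prelu is not a built-in TensorFlow op. Here is a minimal sketch of that call pattern, assuming a TF 1.x-style graph-mode workflow; my_graph_def is a hypothetical placeholder for a GraphDef that may contain fused Prelu nodes:

import tensorflow as tf
from tensorflowjs.converters import fuse_prelu

# Placeholder for a GraphDef that may contain fused Prelu nodes,
# e.g. loaded from a frozen model (hypothetical, for illustration).
my_graph_def = ...

graph = tf.Graph()
# Attach the Prelu function definition to the graph so that the
# import below does not fail on the custom op.
fuse_prelu.register_prelu_func(graph)

with tf.compat.v1.Session(graph=graph) as sess:
  tf.compat.v1.import_graph_def(my_graph_def, name='')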


From tensorflow/tfjs: tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py

def optimize_graph(graph, signature_def, output_graph,
                   tf_version, quantization_dtype=None, skip_op_check=False,
                   strip_debug_ops=False):
  """Takes a Python Graph object and optimizes the graph.

  Args:
    graph: The frozen graph to optimize.
    signature_def: The SignatureDef of the inference graph.
    output_graph: The location of the output graph.
    tf_version: TensorFlow version of the input graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
      compression. Only np.uint8 and np.uint16 are supported.
    skip_op_check: Bool whether to skip the op check.
    strip_debug_ops: Bool whether to strip debug ops.
  """
  fuse_prelu.register_prelu_func(graph)

  # Add a collection 'train_op' so that Grappler knows the outputs.
  for _, output in signature_def.outputs.items():
    name = output.name.split(':')[0]
    graph.add_to_collection('train_op', graph.get_operation_by_name(name))

  graph_def = graph.as_graph_def()

  unsupported = validate(graph_def.node, skip_op_check,
                         strip_debug_ops)
  if unsupported:
    raise ValueError('Unsupported Ops in the model before optimization\n' +
                     ', '.join(unsupported))

  # First pass of Grappler optimization; this is needed for batch norm folding.
  config = config_pb2.ConfigProto()
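
In this first example the registration happens up front: register_prelu_func runs before the output ops are collected and before the Grappler optimization pass, so the graph already carries the Prelu function definition when op validation and optimization inspect it.
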
From tensorflow/tfjs: tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py

    signature_def: The SignatureDef of the inference graph.
    quantization_dtype: An optional numpy dtype to quantize weights to for
        compression. Only np.uint8 and np.uint16 are supported.
  """
  constants = [node for node in graph_def.node if node.op == 'Const']
  const_inputs = {}
  # Remove the conditional inputs for the constants; they are restored below.
  for const in constants:
    const_inputs[const.name] = const.input[:]
    del const.input[:]

  print('Writing weight file ' + output_graph + '...')
  const_manifest = []

  graph = tf.Graph()
  fuse_prelu.register_prelu_func(graph)
  fuse_depthwise_conv2d.register_fused_depthwise_conv2d_func(graph)

  extracted_graph = fuse_depthwise_conv2d.extract_op_attributes(graph_def)
  with tf.compat.v1.Session(graph=graph) as sess:
    tf.import_graph_def(extracted_graph, name='')
    for const in constants:
      tensor = graph.get_tensor_by_name(const.name + ':0')
      value = tensor.eval(session=sess)
      if not isinstance(value, np.ndarray):
        value = np.array(value)

      const_manifest.append({'name': const.name, 'data': value})

      # Restore the conditional inputs
      const.input[:] = const_inputs[const.name]
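
The second example applies the same pattern to weight extraction: a fresh tf.Graph receives both register_prelu_func and register_fused_depthwise_conv2d_func before the extracted GraphDef is imported, so import_graph_def and the constant evaluation succeed even though the graph references fused ops that stock TensorFlow does not define.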