How to use the `coremltools.converters.nnssa.commons.basic_graph_ops.delete_node` function in coremltools

To help you get started, we’ve selected a few coremltools examples that show popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: apple/coremltools — coremltools/converters/nnssa/coreml/graph_pass/op_fusions.py (View on GitHub)
# connect fused node to entry and output nodes
        connect_edge(graph, current_node.name, fused_bn_node.name)
        connect_dests(graph, fused_bn_node.name, bn_outputs)

        # correct output's inputs order
        for out in bn_outputs:
            if len(graph[out].inputs) < 2:
                continue
            out_inputs = graph[out].inputs
            a = out_inputs.index(out_node.name)
            b = out_inputs.index(fused_bn_node.name)
            out_inputs[a], out_inputs[b] = out_inputs[b], out_inputs[a]

        # delete merged nodes
        for name in bn_node_names:
            delete_node(graph, name)
GitHub: apple/coremltools — coremltools/converters/nnssa/coreml/graph_pass/op_fusions.py (View on GitHub)
cropping_values[1] = croppings[1, 1]  # right
                needs_cropping_after = False
                border_mode = n.attr.get('padding', '').lower()
                if sum(cropping_values) != 0:
                    if border_mode != 'valid':
                        needs_cropping_after = True
                    else:
                        raise NotImplementedError('unhandled BatchToSpaceND case.')
                if needs_cropping_after:
                    graph[output_node].attr.update({'_cropping_after': cropping_values})

            # adjust type inference
            shape = list(graph[previous_node].datatype.get_shape())
            graph[output_node].datatype = builtins.tensor(graph[output_node].datatype.get_primitive(), tuple(shape))

            delete_node(graph, n.name)
            count += 1

        if count > 0:
            print('[Op Fusion] Skipped {} BatchToSpaceND / SpaceToBatchND nodes.'.format(count))
GitHub: apple/coremltools — coremltools/converters/nnssa/coreml/graph_pass/op_fusions.py (View on GitHub)
out_node = gelu_nodes[2]
            gelu_outputs = out_node.outputs[:]

            # Instantiate a new fused node in the graph
            fused_gelu_node = ParsedNode()
            fused_gelu_node.op = 'GeLU'
            fused_gelu_node.name = out_node.name + '_gelu'
            fused_gelu_node.attr = {}
            fused_gelu_node.datatype = current_node.datatype

            graph[fused_gelu_node.name] = fused_gelu_node

            # Delete nodes
            gelu_node_names = [x.name for x in gelu_nodes]
            for name in gelu_node_names:
                delete_node(graph, name)

            # Connect fused node to entry and output nodes
            connect_edge(graph, current_node.name, fused_gelu_node.name)
            connect_dests(graph, fused_gelu_node.name, gelu_outputs)

            count += 1

    if count > 0:
        print('[Op Fusion] Fused {} GeLUs.'.format(count))
GitHub: apple/coremltools — coremltools/converters/nnssa/coreml/graph_pass/op_removals.py (View on GitHub)
keys = list(f.graph.keys())
        for k in keys:
            if k not in f.graph:
                continue
            node = f.graph[k]
            if len(node.inputs) != 1 or len(node.outputs) != 1:
                continue
            inp_node = f.graph[node.inputs[0]]
            if node.op == 'Identity' and inp_node.op != 'get_tuple':
                delete_count += 1
                parent_name = f.graph[k].inputs[0]
                disconnect_edge(f.graph, parent_name, k)
                for control_input in f.graph[k].control_inputs:
                    replace_control_dest(f.graph, control_input, k, parent_name)
                replace_node(f.graph, k, parent_name)  # join parent to children
                delete_node(f.graph, k)

    return delete_count
GitHub: apple/coremltools — coremltools/converters/nnssa/frontend/tensorflow/graph_pass/fusedbatchnorm_rewrite.py (View on GitHub)
delete_node(graph, node.name)

    builder = GraphBuilder(graph, node.name + '/', ParsedTFNode)
    x_center = builder.add_elementwise("Sub", [x, estimated_mean])
    scaling_factor = builder.add_elementwise(
        "Mul", [scale, builder.add_elementwise("Rsqrt", [estimated_variance])])
    x_scaled = builder.add_elementwise("Mul", [x_center, scaling_factor])
    x_shifted = builder.add_elementwise("Add", [x_scaled, offset])

    x_final = GraphBuilder(graph, '', ParsedTFNode).add_identity(x_shifted, node.name)

    outputs = [x_final]

    for o in original_node_outputs:
        replace_node(graph, o, outputs[graph[o].attr['index']])
        delete_node(graph, o)
GitHub: apple/coremltools — coremltools/converters/nnssa/frontend/graph_pass/trace_constants.py (View on GitHub)
self.node_value_trace[cur_node_key] == source_const_key:
                    delete_nodes.append(nodename)
                elif node.op == 'get_tuple':
                    my_trace = self.node_value_trace[(
                        fname,
                        nodename,
                    )]
                    parent_trace = self.node_value_trace[(
                        fname,
                        node.inputs[0],
                    )]
                    if type(parent_trace) is list:
                        node.attr['index'] = parent_trace.index(my_trace)

            if fname == source_fname:
                delete_node(source_fn.graph, source_node_name)
                target_fn.graph[source_node_name] = source_const_node
                for d in delete_nodes:
                    if d != source_node_name:
                        delete_node(fn.graph, d)
            elif fname == target_fname:
                # if this is the target function, we rewrite.
                for d in delete_nodes:
                    assert (len(fn.graph[d].inputs) == 1)
                    fn.graph[d].op = 'Identity'
                    disconnect_edge(fn.graph, fn.graph[d].inputs[0], d)
                    connect_edge(fn.graph, source_const_node.name, d)
                    self.node_value_trace[(
                        fname,
                        d,
                    )] = (
                        fname,
GitHub: apple/coremltools — coremltools/converters/nnssa/coreml/graph_pass/op_fusions.py (View on GitHub)
fused_ln_node.op = 'LayerNormalization'
            fused_ln_node.name = out_node.name + '_layernorm'
            fused_ln_node.attr = ln_params
            fused_ln_node.datatype = current_node.datatype

            graph[fused_ln_node.name] = fused_ln_node

            # Connect fused node to entry and output nodes
            connect_edge(graph, current_node.name, fused_ln_node.name)
            replace_node(graph, out_node.name, fused_ln_node.name)
            # connect_dests(graph, fused_ln_node.name, ln_outputs)

            # Delete nodes
            ln_node_names = [x.name for x in ln_nodes]
            for name in ln_node_names:
                delete_node(graph, name)

            count += 1

    if count > 0:
        print('[Op Fusion] Fused {} layer normalizations.'.format(count))