How to use the tf2onnx.rewriter.rnn_utils.get_weights_from_const_node function in tf2onnx

To help you get started, we’ve selected a few examples of get_weights_from_const_node based on popular ways it is used in public projects. All of the snippets below come from the LSTM and GRU rewriters in onnx/tensorflow-onnx and from the copy of tf2onnx vendored into onnx/keras-onnx.

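Across the examples, the idiom is the same: look up a pattern-matched op in the TensorFlow graph, pass it to get_weights_from_const_node, and abandon the rewrite when the helper returns nothing, i.e. the op did not resolve to a constant the rewriter can extract. Below is a minimal sketch of that call pattern, assuming the newer (graph, node) signature used in most of the snippets on this page; extract_cell_weights and the op names are placeholders, not tf2onnx API:

def extract_cell_weights(self, match):
    # ops named here come from an already-matched graph pattern;
    # the names are placeholders for this sketch
    kernel_node = match.get_op("cell_kernel")
    bias_node = match.get_op("cell_bias")

    # get_weights_from_const_node returns the node's weight tensor,
    # or None when the node does not resolve to a usable Const
    kernel = get_weights_from_const_node(self.g, kernel_node)
    bias = get_weights_from_const_node(self.g, bias_node)
    if kernel is None or bias is None:
        return None  # not a rewritable cell; skip this match

    return {"kernel": kernel, "bias": bias}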

github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / lstm_rewriter.py
def get_weight_and_bias(self, context):
        match = context.cell_match

        w_e = match.get_op("cell_kernel")
        w = get_weights_from_const_node(self.g, w_e)
        if w is None:
            return None

        # check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
        # for bias_add data format
        bias_add = match.get_op("bias_add")
        if bias_add.data_format != "NHWC":
            logger.debug("BiasAdd data_format is not NHWC, SKIP")
            return None

        b_e = match.get_op("cell_bias")
        b = get_weights_from_const_node(self.g, b_e)
        if b is None or b.shape[0] != w.shape[1]:
            logger.warning("cell_kernel and cell_bias's dimensions does not match, skip")
            return None

        ft_bias_node = match.get_op("ft_bias")
        ft_bias = get_weights_from_const_node(self.g, ft_bias_node)
        if ft_bias is None:
            return None

        if b.dtype != ft_bias.dtype:
            return None

        return {
            "weight": w,
            "bias": b,
            "ft_bias": ft_bias
        }
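The ft_bias op matched above appears to be TF's scalar forget_bias (the constant that LSTMCell adds to the forget-gate logits, 1.0 by default); the dtype check here, and the len(ft.value) == 1 check in the older variant at the end of this page, guard that it is a scalar of the same type as cell_bias so it can later be folded into the ONNX bias.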

github onnx / keras-onnx / keras2onnx / ktf2onnx / tf2onnx / rewriter / gru_rewriter.py
def get_weight_and_bias(self, context):
        match = context.cell_match

        gate_kernel = get_weights_from_const_node(self.g, match.get_op("gate_kernel"))
        gate_bias = get_weights_from_const_node(self.g, match.get_op("gate_bias"))
        res = {
            "gate_kernel": gate_kernel,
            "gate_bias": gate_bias
        }

        # the cell types differ on the memory gate:
        # GRUCell: h'_t = tanh(concat(x_t, r_t .* h_t-1) * W + b)
        # CudnnCompatibleGRUCell: h'_t = tanh(x_t * W_x + b_x + r_t .* (h_t-1 * W_h + b_h))
        if self.gru_cell_type == RNNUnitType.CudnnCompatibleGRUCell:
            hidden_state_kernel = get_weights_from_const_node(
                self.g, match.get_op("hidden_state_kernel")
            )
            hidden_state_bias = get_weights_from_const_node(
                self.g, match.get_op("hidden_state_bias")
            )
            hidden_input_kernel = get_weights_from_const_node(
                self.g, match.get_op("hidden_input_kernel")
            )
            hidden_input_bias = get_weights_from_const_node(
                self.g, match.get_op("hidden_input_bias")
            )
            if not all(val is not None for val in [
                    hidden_state_kernel, hidden_state_bias,
                    hidden_input_kernel, hidden_input_bias
            ]):
                logger.debug("rnn weights check failed, skip")
                return None
            hidden_kernel = np.concatenate([hidden_input_kernel, hidden_state_kernel])
            # apply the linear transformation before multiplying by the output of the reset gate
            context.attributes["linear_before_reset"] = 1
            res["hidden_kernel"] = hidden_kernel
            res["hidden_bias"] = hidden_input_bias
            # recurrence bias for the hidden gate
            res["Rb_h"] = hidden_state_bias
        elif self.gru_cell_type in [RNNUnitType.GRUCell, RNNUnitType.GRUBlockCell]:
            hidden_kernel = get_weights_from_const_node(self.g, match.get_op("hidden_kernel"))
            hidden_bias = get_weights_from_const_node(self.g, match.get_op("hidden_bias"))
            res["hidden_kernel"] = hidden_kernel
            res["hidden_bias"] = hidden_bias

        if not all(val is not None for val in res.values()):
            logger.debug("rnn weights check failed, skip")
            return None

        logger.debug("find needed weights")
        return res
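The two formulas in the comments are easy to check by hand. Here is a small numpy sketch, with toy shapes and variable names of my own (nothing here is tf2onnx API), showing why CudnnCompatibleGRUCell corresponds to ONNX's linear_before_reset=1: the reset gate scales the result of the recurrent linear map rather than the previous state itself:

import numpy as np

rng = np.random.default_rng(0)
num_inputs, num_units = 3, 4
x_t = rng.standard_normal(num_inputs)    # current input
h_prev = rng.standard_normal(num_units)  # previous hidden state
r_t = rng.uniform(size=num_units)        # reset gate output (assumed given)

# GRUCell / GRUBlockCell: one fused kernel over concat(x_t, r_t .* h_prev)
W = rng.standard_normal((num_inputs + num_units, num_units))
b = rng.standard_normal(num_units)
h_cand_fused = np.tanh(np.concatenate([x_t, r_t * h_prev]) @ W + b)

# CudnnCompatibleGRUCell: separate input/recurrent kernels; the reset gate
# is applied AFTER the recurrent linear map -> ONNX linear_before_reset=1
W_x = rng.standard_normal((num_inputs, num_units))
W_h = rng.standard_normal((num_units, num_units))
b_x = rng.standard_normal(num_units)
b_h = rng.standard_normal(num_units)
h_cand_cudnn = np.tanh(x_t @ W_x + b_x + r_t * (h_prev @ W_h + b_h))

This is also why the rewriter keeps hidden_input_bias and hidden_state_bias apart (res["hidden_bias"] vs res["Rb_h"]): ONNX GRU carries separate input and recurrence biases, and only the separated form can express the Cudnn variant exactly.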
github onnx / tensorflow-onnx / tf2onnx / rewriter / grublock_rewriter.py
def get_weight_and_bias(self, match):
        node = match.get_op("GRUBlockCell")
        # per TF's definition of GRUBlockCell, its inputs are, in order:
        # 0-input, 1-state, 2-gate_kernel, 3-hidden_kernel, 4-gate_bias, 5-hidden_bias
        gate_kernel = get_weights_from_const_node(self.g, node.inputs[2].inputs[0])
        gate_bias = get_weights_from_const_node(self.g, node.inputs[4].inputs[0])
        hidden_kernel = get_weights_from_const_node(self.g, node.inputs[3].inputs[0])
        hidden_bias = get_weights_from_const_node(self.g, node.inputs[5].inputs[0])
        # "not all([...])" would test numpy-array truthiness, which is
        # ambiguous for multi-element arrays; compare against None explicitly
        if not all(val is not None for val in [gate_kernel, gate_bias, hidden_kernel, hidden_bias]):
            log.debug("rnn weights check failed, skip")
            return None
        log.debug("find needed weights")
        res = {"gate_kernel": gate_kernel,
               "gate_bias": gate_bias,
               "hidden_kernel": hidden_kernel,
               "hidden_bias": hidden_bias}
        return res
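These indices mirror tf.raw_ops.GRUBlockCell(x, h_prev, w_ru, w_c, b_ru, b_c): gate_kernel/gate_bias are the fused reset-and-update parameters (w_ru, b_ru), and hidden_kernel/hidden_bias the candidate-state parameters (w_c, b_c). Each node.inputs[i].inputs[0] steps from the cell's i-th input back to its producer, which the rewriter expects to resolve to the Const holding the weights.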
github onnx / tensorflow-onnx / tf2onnx / rewriter / lstm_rewriter.py
def get_weight_and_bias(self, match):
        # if any expected op fails to match, just return
        w_e = match.get_op("cell_kernel")
        w = get_weights_from_const_node(w_e)
        if not w:
            return None

        # check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
        # for the bias_add data format
        bias_add = match.get_op("bias_add")
        if bias_add.data_format != "NHWC":
            log.debug("BiasAdd data_format is not NHWC, SKIP")
            return None

        b_e = match.get_op("cell_bias")
        b = get_weights_from_const_node(b_e)
        if not b or b.value.shape[0] != w.value.shape[1]:
            log.warning("cell_kernel and cell_bias's dimensions does not match, skip")
            return None

        ft_bias = match.get_op("ft_bias")
        ft = get_weights_from_const_node(ft_bias)
        if not ft:
            return None

        if not (len(ft.value) == 1 and b_e.dtype == ft_bias.dtype):
            return None

        return RnnWeights(w, b, ft)
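Note the two signatures in circulation. The copy vendored into keras-onnx takes the graph as its first argument, get_weights_from_const_node(self.g, node), and returns a value that is checked with is None; the older tensorflow-onnx variant shown here takes only the node and returns an object carrying .value and .dtype, which is then packed into RnnWeights. Check which version you are targeting before copying either idiom.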