# References
- [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks
with Low Bitwidth Gradients](https://arxiv.org/abs/1606.06160)
"""
    # Inputs are clipped to [0, 1] before quantization.
    x = tf.clip_by_value(x, 0.0, 1.0)

    @tf.custom_gradient
    def _k_bit_with_identity_grad(x):
        # Round to the nearest of n + 1 evenly spaced values in [0, 1]; the
        # gradient is passed through unchanged (straight-through estimator).
        n = 2 ** k_bit - 1
        return tf.round(x * n) / n, lambda dy: dy

    return _k_bit_with_identity_grad(x)
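
# Illustrative sketch (not part of the library): with k_bit=2 there are
# n = 2**2 - 1 = 3 quantization steps, so clipped inputs snap to
# {0, 1/3, 2/3, 1}. The enclosing function name `dorefa_quantizer(x, k_bit)`
# is assumed here, matching the DoReFaQuantizer wrapper below.
def _demo_dorefa_rounding():
    x = tf.constant([-0.2, 0.1, 0.4, 0.8, 1.3])
    # Values below 0 or above 1 are clipped before rounding.
    print(dorefa_quantizer(x, k_bit=2).numpy())  # [0. 0. 0.33333334 0.6666667 1.]
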
@utils.register_keras_custom_object
class SteSign(QuantizerFunctionWrapper):
r"""Instantiates a serializable binary quantizer.
\\[
q(x) = \begin{cases}
-1 & x < 0 \\\
1 & x \geq 0
\end{cases}
\\]
The gradient is estimated using the Straight-Through Estimator
(essentially the binarization is replaced by a clipped identity on the
backward pass).
\\[\frac{\partial q(x)}{\partial x} = \begin{cases}
1 & \left|x\right| \leq \texttt{clip_value} \\\
0 & \left|x\right| > \texttt{clip_value}
\end{cases}\\]
    ```plot-activation
    quantizers.ste_sign
    ```
# Arguments
clip_value: Threshold for clipping gradients. If `None` gradients are not clipped.
# References
- [Binarized Neural Networks: Training Deep Neural Networks with Weights and
Activations Constrained to +1 or -1](http://arxiv.org/abs/1602.02830)
"""
def __init__(self, clip_value=1.0):
super().__init__(ste_sign, clip_value=clip_value)
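
# A minimal sketch of the straight-through estimator used by `SteSign`
# (illustration only, not the library's exact implementation): the forward
# pass binarizes, while the backward pass acts as the identity wherever
# |x| <= clip_value and as zero elsewhere.
def _ste_sign_sketch(x, clip_value=1.0):
    @tf.custom_gradient
    def _binarize(x):
        def grad(dy):
            # Pass gradients through only inside the clipping interval.
            return tf.where(tf.abs(x) <= clip_value, dy, tf.zeros_like(dy))

        # Map x >= 0 to +1 and x < 0 to -1 (tf.sign alone would map 0 to 0).
        return tf.sign(tf.sign(x) + 0.5), grad

    return _binarize(x)
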
@utils.register_keras_custom_object
class SteHeaviside(QuantizerFunctionWrapper):
r"""
Instantiates a binarization quantizer with output values 0 and 1.
\\[
q(x) = \begin{cases}
+1 & x > 0 \\\
0 & x \leq 0
\end{cases}
\\]
The gradient is estimated using the Straight-Through Estimator
(essentially the binarization is replaced by a clipped identity on the
backward pass).
\\[\frac{\partial q(x)}{\partial x} = \begin{cases}
1 & \left|x\right| \leq 1 \\\
0 & \left|x\right| > 1
    \end{cases}\\]

    ```plot-activation
    quantizers.ste_heaviside
    ```

    # Arguments
    clip_value: Threshold for clipping gradients. If `None` gradients are not clipped.

    # Returns
    AND Binarization function
    """

    def __init__(self, clip_value=1.0):
        super().__init__(ste_heaviside, clip_value=clip_value)
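
# Hypothetical usage sketch: quantizer instances are assumed callable here,
# so the forward mapping can be checked directly (assumes eager execution).
def _demo_ste_heaviside():
    x = tf.constant([-1.5, -0.3, 0.0, 0.4, 2.0])
    # Positives map to 1, everything else (including 0) maps to 0.
    print(SteHeaviside()(x).numpy())  # [0. 0. 0. 1. 1.]
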
@utils.register_keras_custom_object
class MagnitudeAwareSign(QuantizerFunctionWrapper):
r"""Instantiates a serializable magnitude-aware sign quantizer for Bi-Real Net.
A scaled sign function computed according to Section 3.3 in
[Zechun Liu et al](https://arxiv.org/abs/1808.00278).
```plot-activation
quantizers._scaled_sign
```
# Arguments
clip_value: Threshold for clipping gradients. If `None` gradients are not clipped.
# References
- [Bi-Real Net: Enhancing the Performance of 1-bit CNNs With Improved
Representational Capability and Advanced Training
Algorithm](https://arxiv.org/abs/1808.00278)
"""
def __init__(self, clip_value=1.0):
        super().__init__(magnitude_aware_sign, clip_value=clip_value)
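
# Minimal sketch of the magnitude-aware rescaling (an assumption based on
# Section 3.3 of Bi-Real Net, not larq's exact code): the binarized kernel
# is scaled by the mean absolute value of the real-valued weights, computed
# per output channel, with no gradient flowing through the scale.
def _magnitude_aware_sign_sketch(w):
    # Average |w| over all axes except the last (output-channel) axis.
    reduce_axes = list(range(len(w.shape) - 1))
    scale = tf.stop_gradient(
        tf.reduce_mean(tf.abs(w), axis=reduce_axes, keepdims=True)
    )
    return scale * tf.sign(tf.sign(w) + 0.5)
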
@utils.register_keras_custom_object
class SwishSign(QuantizerFunctionWrapper):
r"""Sign binarization function.
\\[
q(x) = \begin{cases}
-1 & x < 0 \\\
1 & x \geq 0
\end{cases}
\\]
The gradient is estimated using the SignSwish method.
\\[
\frac{\partial q_{\beta}(x)}{\partial x} = \frac{\beta\left\\{2-\beta x \tanh \left(\frac{\beta x}{2}\right)\right\\}}{1+\cosh (\beta x)}
\\]
    ```plot-activation
    quantizers.swish_sign
    ```
    # Arguments
    beta: Larger values result in a closer approximation to the derivative of the sign.

    # Returns
    SwishSign quantization function

    # References
    - [BNN+: Improved Binary Network Training](https://arxiv.org/abs/1812.11800)
"""
    def __init__(self, beta=5.0):
        super().__init__(swish_sign, beta=beta)
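
# Numeric sanity check (illustrative): the closed-form derivative quoted in
# the `SwishSign` docstring should agree with autodiff of the smooth
# SignSwish surrogate from BNN+,
# SS_beta(x) = 2*sigma(beta*x) * (1 + beta*x*(1 - sigma(beta*x))) - 1.
def _check_swish_sign_grad(beta=5.0):
    x = tf.constant([-0.5, -0.1, 0.0, 0.2, 0.7])
    with tf.GradientTape() as tape:
        tape.watch(x)
        s = tf.sigmoid(beta * x)
        surrogate = 2.0 * s * (1.0 + beta * x * (1.0 - s)) - 1.0
    autodiff = tape.gradient(surrogate, x)
    closed_form = (
        beta * (2.0 - beta * x * tf.tanh(beta * x / 2.0))
        / (1.0 + tf.math.cosh(beta * x))
    )
    # Both lines should print (numerically) identical values.
    print(autodiff.numpy())
    print(closed_form.numpy())
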
@utils.register_keras_custom_object
class SteTern(QuantizerFunctionWrapper):
r"""Instantiates a serializable ternarization quantizer.
\\[
q(x) = \begin{cases}
+1 & x > \Delta \\\
0 & |x| < \Delta \\\
-1 & x < - \Delta
\end{cases}
\\]
where \\(\Delta\\) is defined as the threshold and can be passed as an argument,
or can be calculated as per the Ternary Weight Networks original paper, such that
\\[
\Delta = \frac{0.7}{n} \sum_{i=1}^{n} |W_i|
\\]
    # Arguments
    threshold_value: The value \\(\Delta\\) used as the ternarization threshold.
    ternary_weight_networks: Boolean, whether to compute the threshold from the
    weights as in the Ternary Weight Networks paper.
    clip_value: Threshold for clipping gradients. If `None` gradients are not clipped.

    # References
    - [Ternary Weight Networks](http://arxiv.org/abs/1605.04711)
"""
def __init__(
self, threshold_value=0.05, ternary_weight_networks=False, clip_value=1.0
):
super().__init__(
ste_tern,
threshold_value=threshold_value,
ternary_weight_networks=ternary_weight_networks,
clip_value=clip_value,
)
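
# Sketch of the ternarization forward pass (mirrors the docstring formulas;
# an assumption, not necessarily larq's exact implementation). With
# `ternary_weight_networks=True` the threshold is 0.7 * mean(|x|), as in the
# Ternary Weight Networks paper.
def _ste_tern_sketch(x, threshold_value=0.05, ternary_weight_networks=False):
    if ternary_weight_networks:
        threshold = 0.7 * tf.reduce_mean(tf.abs(x))
    else:
        threshold = threshold_value
    # Zero out values inside (-threshold, threshold), then take the sign.
    return tf.sign(tf.where(tf.abs(x) < threshold, tf.zeros_like(x), x))
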
@utils.register_keras_custom_object
class DoReFaQuantizer(QuantizerFunctionWrapper):
r"""Instantiates a serializable k_bit quantizer as in the DoReFa paper.
\\[
q(x) = \begin{cases}
0 & x < \frac{1}{2n} \\\
    \frac{i}{n} & \frac{2i-1}{2n} < |x| < \frac{2i+1}{2n} \text{ for } i \in \\{1, \ldots, n-1\\}\\\
1 & \frac{2n-1}{2n} < x
\end{cases}
\\]
where \\(n = 2^{\text{k_bit}} - 1\\). The number of bits, k_bit, needs to be passed as an argument.
The gradient is estimated using the Straight-Through Estimator
(essentially the binarization is replaced by a clipped identity on the
backward pass).
\\[\frac{\partial q(x)}{\partial x} = \begin{cases}
    1 & 0 \leq x \leq 1 \\\
    0 & \text{else}
    \end{cases}\\]