How to use the edward2.layers.LSTMCellReparameterization class in edward2

To help you get started, we’ve selected a few edward2 examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: Google-Health/records-research/model-uncertainty/bayesian_rnn_model.py (view on GitHub)
import constants
import util
import embedding


@tf.custom_gradient
def _clip_gradient(x, clip_norm):
  """Passes `x` through unchanged and clips its gradient on the backward pass.

  Forward: `x` is returned as-is. Backward: the incoming gradient is clipped
  with `tf.clip_by_global_norm` applied to a one-element list. Its norm is
  therefore bounded by `clip_norm`. `clip_norm` itself receives a zero
  gradient.

  Args:
    x: Tensor to pass through.
    clip_norm: Maximum norm allowed for the gradient flowing back into `x`.

  Returns:
    The `(output, grad_fn)` pair expected by `tf.custom_gradient`.
  """
  def _backward(upstream):
    # `tf.custom_gradient` needs one gradient per input of `_clip_gradient`,
    # so `clip_norm` is given an explicit zero gradient.
    clipped, _ = tf.clip_by_global_norm([upstream], clip_norm)
    return clipped[0], tf.constant(0.)

  return x, _backward


class LSTMCellReparameterizationGradClipped(
    ed.layers.LSTMCellReparameterization):
  """Bayesian LSTMCell with per-time step gradient clipping.

  On every call, the cell's inputs and each recurrent state tensor are wrapped
  in `_clip_gradient`. Gradients flowing back through a single time step are
  therefore norm-clipped to `clip_norm`, tensor by tensor, before the
  underlying `ed.layers.LSTMCellReparameterization` step runs.
  """

  def __init__(self, clip_norm, *args, **kwargs):
    """Creates the cell.

    Args:
      clip_norm: Maximum gradient norm applied independently to the inputs
        and to each state tensor at every time step. It is stored in the
        layer config, so it should be serializable (presumably a Python
        float — TODO confirm against callers).
      *args: Positional arguments forwarded to
        `ed.layers.LSTMCellReparameterization`.
      **kwargs: Keyword arguments forwarded to
        `ed.layers.LSTMCellReparameterization`.
    """
    # Initialize the Keras layer machinery before setting attributes on it.
    super(LSTMCellReparameterizationGradClipped, self).__init__(*args, **kwargs)
    self.clip_norm = clip_norm

  def call(self, inputs, states, training=None):
    """Clips gradients of `inputs` and `states`, then runs the parent cell."""
    with tf.name_scope("clip_gradient"):
      inputs = _clip_gradient(inputs, self.clip_norm)
      states = [_clip_gradient(x, self.clip_norm) for x in states]
    return super(LSTMCellReparameterizationGradClipped, self).call(
        inputs, states, training)

  def get_config(self):
    """Returns the layer config, including `clip_norm`.

    `clip_norm` is a required constructor argument. Without this override it
    would be missing from the config, and `from_config` (`cls(**config)`)
    would fail when the cell, or an RNN wrapping it, is deserialized.
    """
    config = super(LSTMCellReparameterizationGradClipped, self).get_config()
    config["clip_norm"] = self.clip_norm
    return config


class BayesianRNN(tf.keras.Model):
  """Bayesian RNN model."""