import math
import time

import torch
from torch import nn


def train_and_predict_rnn_nn(model, num_hiddens, init_gru_state, corpus_indices,
                             vocab, device, num_epochs, num_steps, lr,
                             clipping_theta, batch_size, prefixes, num_layers=1):
    """Train an RNN model and predict the next item in the sequence."""
    loss = nn.CrossEntropyLoss()
    optm = torch.optim.SGD(model.parameters(), lr=lr)
    start = time.time()
    for epoch in range(1, num_epochs + 1):
        l_sum, n = 0.0, 0
        data_iter = data_iter_consecutive(
            corpus_indices, batch_size, num_steps, device)
        state = model.begin_state(batch_size=batch_size, num_hiddens=num_hiddens,
                                  device=device, num_layers=num_layers)
        for X, Y in data_iter:
            # Detach the state in place so gradients do not flow across
            # mini-batch boundaries
            for s in state:
                s.detach_()
            X = X.to(dtype=torch.long)
            (output, state) = model(X, state)
            # Transpose Y so the labels line up with the flattened output
            y = Y.t().reshape((-1,))
            l = loss(output, y.long()).mean()
            optm.zero_grad()
            l.backward(retain_graph=True)
            with torch.no_grad():
                # Clip the gradient
                grad_clipping_nn(model, clipping_theta, device)
                # Since the error has already taken the mean, the gradient
                # does not need to be averaged
                optm.step()
            l_sum += l.item() * y.numel()
            n += y.numel()
        # The original snippet is truncated here; the reporting below is an
        # assumed reconstruction, and predict_rnn_nn is presumed to be a
        # helper defined elsewhere in this module.
        if epoch % 50 == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch, math.exp(l_sum / n), time.time() - start))
            start = time.time()
            for prefix in prefixes:
                print(' -', predict_rnn_nn(prefix, 50, model, vocab, device,
                                           num_hiddens, num_layers))
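

# The trainer above leans on grad_clipping_nn. What follows is a minimal
# sketch of a global-norm clipper in that spirit, assuming the real helper
# scales all model gradients so their combined norm is at most theta; the
# module's actual definition may differ, so this is illustrative only.
def grad_clipping_nn_sketch(model, theta, device):
    """Scale gradients of `model` so their global norm is at most theta."""
    norm = torch.tensor([0.0], device=device)
    for param in model.parameters():
        if param.grad is not None:
            norm += (param.grad ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(theta / norm)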


def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
                          corpus_indices, vocab, device, is_random_iter,
                          num_epochs, num_steps, lr, clipping_theta,
                          batch_size, prefixes):
    """Train an RNN model and predict the next item in the sequence."""
    if is_random_iter:
        data_iter_fn = data_iter_random
    else:
        data_iter_fn = data_iter_consecutive
    params = get_params()
    loss = nn.CrossEntropyLoss()
    start = time.time()
    for epoch in range(num_epochs):
        if not is_random_iter:
            # If adjacent sampling is used, the hidden state is initialized
            # at the beginning of the epoch
            state = init_rnn_state(batch_size, num_hiddens, device)
        l_sum, n = 0.0, 0
        data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device)
        for X, Y in data_iter:
            if is_random_iter:
                # If random sampling is used, the hidden state is initialized
                # before each mini-batch update
                state = init_rnn_state(batch_size, num_hiddens, device)
            else:
                # Otherwise, detach the state from the previous mini-batch
                for s in state:
                    s.detach_()
            # Assumed reconstruction of the truncated original from here on;
            # to_onehot, grad_clipping, sgd, predict_rnn are presumed helpers.
            inputs = to_onehot(X, len(vocab))
            (outputs, state) = rnn(inputs, state, params)
            outputs = torch.cat(outputs, dim=0)
            y = Y.t().reshape((-1,))
            l = loss(outputs, y.long())
            for param in params:  # clear stale gradients before backward
                if param.grad is not None:
                    param.grad.zero_()
            l.backward(retain_graph=True)
            with torch.no_grad():
                grad_clipping(params, clipping_theta, device)
                # The loss is already a mean, so no further averaging is needed
                sgd(params, lr, 1)
            l_sum += l.item() * y.numel()
            n += y.numel()
        if (epoch + 1) % 50 == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch + 1, math.exp(l_sum / n), time.time() - start))
            for prefix in prefixes:
                print(' -', predict_rnn(prefix, 50, rnn, params,
                                        init_rnn_state, num_hiddens,
                                        vocab, device))
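

# Hedged usage sketch for the nn-based trainer: `GRUModel` and `load_corpus`
# are hypothetical stand-ins for this module's real model class and corpus
# loader, and the hyperparameter values are illustrative only.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    corpus_indices, vocab = load_corpus()         # hypothetical loader
    model = GRUModel(len(vocab), 256).to(device)  # hypothetical model class
    # init_gru_state is unused by the nn trainer above, so None is passed
    train_and_predict_rnn_nn(
        model, 256, None, corpus_indices, vocab, device,
        num_epochs=250, num_steps=35, lr=1e-2, clipping_theta=1e-2,
        batch_size=32, prefixes=['traveller'])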