Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
corpus_indices, vocab, device, is_random_iter,
num_epochs, num_steps, lr, clipping_theta,
batch_size, prefixes):
"""Train an RNN model and predict the next item in the sequence."""
if is_random_iter:
data_iter_fn = data_iter_random
else:
data_iter_fn = data_iter_consecutive
params = get_params()
loss = nn.CrossEntropyLoss()
start = time.time()
for epoch in range(num_epochs):
if not is_random_iter:
# If adjacent sampling is used, the hidden state is initialized
# at the beginning of the epoch
state = init_rnn_state(batch_size, num_hiddens, device)
l_sum, n = 0.0, 0
data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device)
for X, Y in data_iter:
if is_random_iter:
# If random sampling is used, the hidden state is initialized
# before each mini-batch update