How to use the d2l.data.data_iter_consecutive function in d2l

To help you get started, we’ve selected a few d2l examples that reflect popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github dsgiitr / d2l-pytorch / d2l / train.py View on Github external
def train_and_predict_rnn_nn(model, num_hiddens, init_gru_state, corpus_indices, vocab,
                                device, num_epochs, num_steps, lr,
                                clipping_theta, batch_size, prefixes, num_layers=1):
    """Train an RNN model and predict the next item in the sequence.

    NOTE: this excerpt is truncated. Only the start of the per-batch training
    step (forward pass, loss, backward pass, gradient clipping) is visible;
    the optimizer step, loss accumulation and prediction code are not shown.

    Args:
        model: Object exposing ``parameters()`` and ``begin_state(...)`` and
            callable as ``model(X, state)`` returning ``(output, state)``.
        num_hiddens: Forwarded to ``model.begin_state``.
        init_gru_state: Not used in the visible part of the function.
        corpus_indices: Forwarded to ``data_iter_consecutive``.
        vocab: Not used in the visible part of the function.
        device: Forwarded to ``data_iter_consecutive``,
            ``model.begin_state`` and ``grad_clipping_nn``.
        num_epochs: Number of passes of the outer epoch loop.
        num_steps: Forwarded to ``data_iter_consecutive``.
        lr: Learning rate given to ``torch.optim.SGD``.
        clipping_theta: Forwarded to ``grad_clipping_nn``.
        batch_size: Forwarded to ``data_iter_consecutive`` and
            ``model.begin_state``.
        prefixes: Not used in the visible part of the function.
        num_layers: Forwarded to ``model.begin_state``; defaults to 1.
    """
    loss =  nn.CrossEntropyLoss()
    optm = torch.optim.SGD(model.parameters(), lr=lr)
    # Start timestamp; the code that reads it is outside this excerpt.
    start = time.time()
    # Epochs are numbered from 1 to num_epochs inclusive.
    for epoch in range(1, num_epochs+1):
        # Per-epoch accumulators; they are not updated in the visible excerpt.
        l_sum, n = 0.0, 0
        data_iter = data_iter_consecutive(
            corpus_indices, batch_size, num_steps, device)
        # The state is created once per epoch and then carried from one
        # mini-batch to the next inside the loop below.
        state = model.begin_state(batch_size=batch_size, num_hiddens=num_hiddens, device=device ,num_layers=num_layers)
        for X, Y in data_iter:
            # NOTE(review): if the elements of `state` are torch tensors,
            # `detach()` returns a new tensor and the result is discarded
            # here, so `state` itself presumably stays attached to the
            # previous batch's graph. The in-place variant is `detach_()`.
            # TODO confirm the type of `state` before changing this.
            for s in state:
                s.detach()
            X = X.to(dtype=torch.long)
            (output, state) = model(X, state)
            # Transpose Y and flatten it to 1-D to build the target vector.
            # Presumably this matches the row order of `output`; TODO confirm.
            y = Y.t().reshape((-1,))
            # NOTE(review): nn.CrossEntropyLoss() with default arguments
            # already returns the mean, so `.mean()` looks redundant here.
            l = loss(output, y.long()).mean()
            optm.zero_grad()
            # NOTE(review): retain_graph=True is presumably needed because the
            # carried-over state is still attached to earlier graphs (see the
            # detach note above); confirm whether it can be dropped.
            l.backward(retain_graph=True)
            with torch.no_grad():
                # Clip the gradient
                grad_clipping_nn(model, clipping_theta, device)
                # Since the error has already taken the mean, the gradient does
                # not need to be averaged
github dsgiitr / d2l-pytorch / d2l / train.py View on Github external
def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
                          corpus_indices, vocab, device, is_random_iter,
                          num_epochs, num_steps, lr, clipping_theta,
                          batch_size, prefixes):
    """Train an RNN model and predict the next item in the sequence."""
    if is_random_iter:
        data_iter_fn = data_iter_random
    else:
        data_iter_fn = data_iter_consecutive
    params = get_params()
    loss =  nn.CrossEntropyLoss()
    start = time.time()
    for epoch in range(num_epochs):
        if not is_random_iter:
            # If adjacent sampling is used, the hidden state is initialized
            # at the beginning of the epoch
            state = init_rnn_state(batch_size, num_hiddens, device)
        l_sum, n = 0.0, 0
        data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device)
        for X, Y in data_iter:
            if is_random_iter:
                # If random sampling is used, the hidden state is initialized
                # before each mini-batch update
                state = init_rnn_state(batch_size, num_hiddens, device)
            else: