How to use the torchelastic.State class in torchelastic

To help you get started, we’ve selected a few torchelastic examples, based on popular ways it is used in public projects.


From pytorch/elastic, examples/imagenet:
    # (earlier branch of the learning-rate schedule not shown in this excerpt)
    elif epoch < 80:
        lr = world_size * params.base_learning_rate * (0.1 ** (epoch // 30))
    else:
        lr = world_size * params.base_learning_rate * (0.1 ** 3)
    for param_group in optimizer.param_groups:
        lr_old = param_group["lr"]
        param_group["lr"] = lr
        # Trick: apply momentum correction when lr is updated
        if lr > lr_old:
            param_group["momentum"] = lr / lr_old * 0.9  # momentum
        else:
            param_group["momentum"] = 0.9  # default momentum
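
The schedule above can be isolated into a standalone function to see what it computes. This is a sketch under stated assumptions: the name `adjust_learning_rate` and its signature are hypothetical, `param_groups` is a plain list of dicts standing in for `optimizer.param_groups`, `base_learning_rate` is passed directly instead of through a `params` object, and the elided warmup branch is left out (every epoch below 80 takes the step-decay path).

```python
def adjust_learning_rate(world_size, base_learning_rate, param_groups, epoch):
    """Hypothetical helper: step-decay lr scaled by world size, plus the
    momentum-correction trick from the excerpt. The warmup branch is omitted."""
    if epoch < 80:
        # Epochs 0-29: x1, 30-59: x0.1, 60-79: x0.01 of the scaled base rate.
        lr = world_size * base_learning_rate * (0.1 ** (epoch // 30))
    else:
        # From epoch 80 on: x0.001 of the scaled base rate.
        lr = world_size * base_learning_rate * (0.1 ** 3)
    for group in param_groups:
        lr_old = group["lr"]
        group["lr"] = lr
        if lr > lr_old:
            # lr went up: scale momentum by the lr ratio (momentum correction).
            group["momentum"] = lr / lr_old * 0.9
        else:
            # lr unchanged or lower: reset to the default momentum.
            group["momentum"] = 0.9
    return lr
```

For example, with `world_size=4` and `base_learning_rate=0.1`, epoch 0 yields an lr of 0.4 and epoch 30 yields 0.04. Scaling by `world_size` follows the excerpt's linear-scaling choice: more workers means a larger global batch, so the base rate is multiplied accordingly.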

class ImagenetState(torchelastic.State):
    """
    Client-provided State object; it is serializable and captures the entire
    state needed for executing one iteration of training.
    """

    def __init__(self, model, params, dataset, num_epochs, epoch=0):
        self.model = model
        self.params = params
        self.dataset = dataset
        self.total_batch_size = params.batch_per_device

        self.num_epochs = num_epochs
        self.epoch = epoch

        self.iteration = 0
        self.data_start_index = 0
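
The docstring's key property, that the state is serializable and holds everything needed to resume, can be illustrated with plain Python. This is a minimal sketch assuming only the progress counters from `ImagenetState` (epoch, iteration, data start index); the class `TrainingProgress` and its `snapshot`/`restore` methods are hypothetical and are not claimed to match the `torchelastic.State` interface. A real state would also carry model and optimizer parameters.

```python
import pickle
from dataclasses import asdict, dataclass


@dataclass
class TrainingProgress:
    """Hypothetical stand-in for the progress fields of ImagenetState."""
    num_epochs: int
    epoch: int = 0
    iteration: int = 0
    data_start_index: int = 0

    def snapshot(self) -> dict:
        # Copy every field into a plain dict that pickles cleanly.
        return asdict(self)

    def restore(self, snapshot: dict) -> None:
        # Overwrite every field from a previously captured snapshot.
        for key, value in snapshot.items():
            setattr(self, key, value)


# Serialize the progress, then rebuild it on a "fresh" object,
# as a worker that joins mid-training might.
state = TrainingProgress(num_epochs=90, epoch=3, iteration=120, data_start_index=5120)
blob = pickle.dumps(state.snapshot())
restored = TrainingProgress(num_epochs=90)
restored.restore(pickle.loads(blob))
```

Keeping the snapshot a plain dict of counters means any worker can rebuild its position in the data without having shared memory with the worker that produced it.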

From pytorch/elastic, torchelastic:
def train(elastic_coordinator, train_step, state):
    """
    This is the main elastic data parallel loop. It starts from an initial 'state'.
    Each iteration calls 'train_step' and returns a new state. 'train_step'
    has the following interface:
        state, worker_stats = train_step(state)
    When 'train_step' exhausts all the data, a StopIteration exception should be
    raised.
    """
    assert isinstance(state, torchelastic.State)

    failure_count = 0
    rank = 0

    checkpoint_util = CheckpointUtil(elastic_coordinator)

    while not elastic_coordinator.should_stop_training():
        # See:
        if failure_count >= MAX_FAILURES:
            e = RuntimeError(
                "Exceeded max number of recoverable failures: {}".format(failure_count)
            )
            raise e

        start_time = time.time()
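
The loop contract from the docstring (`state, worker_stats = train_step(state)`, with StopIteration signalling exhausted data) and the cap on recoverable failures can be illustrated without torchelastic. This is a simplified single-process sketch: `run_train_loop` and `max_failures` are hypothetical names (the excerpt uses a module constant `MAX_FAILURES` whose value is not shown), and there is no coordinator, rendezvous, or checkpointing, so it is not torchelastic's implementation.

```python
from typing import Any, Callable, List, Tuple


def run_train_loop(
    train_step: Callable[[Any], Tuple[Any, Any]],
    state: Any,
    max_failures: int = 3,
) -> Tuple[Any, List[Any]]:
    """Hypothetical sketch: call train_step until it raises StopIteration.

    Any other exception counts as a recoverable failure, and the step is
    retried with the state returned by the last successful call.
    """
    failure_count = 0
    all_stats = []
    while True:
        # Same guard as the excerpt: give up once the failure budget is spent.
        if failure_count >= max_failures:
            raise RuntimeError(
                "Exceeded max number of recoverable failures: {}".format(failure_count)
            )
        try:
            new_state, worker_stats = train_step(state)
        except StopIteration:
            # Data exhausted: training finished normally.
            return state, all_stats
        except Exception:
            # Recoverable failure: count it and retry from the last good state.
            failure_count += 1
            continue
        state = new_state
        all_stats.append(worker_stats)
```

For instance, a `train_step` that advances an integer counter and raises StopIteration at 3 returns a final state of 3 with three stats entries. Separating "done" (StopIteration) from "failed" (any other exception) is what lets the loop retry transient errors without mistaking the end of data for a crash.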