How to use the torchelastic.SimpleWorkerStats function in torchelastic

To help you get started, we’ve selected a few torchelastic examples based on popular ways the library is used in public projects.

Source: pytorch/elastic, examples/imagenet/main.py (view on GitHub)
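
The snippet below is the tail of the example's train_step: it logs progress from one rank per host (assuming one process per GPU), advances the data-index bookkeeping, snapshots the model state, and returns the per-step throughput wrapped in torchelastic.SimpleWorkerStats.
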
    # Fragment of the example's train_step(state). Assumed from the
    # surrounding file: imports of time, torch, torch.distributed as dist,
    # and torchelastic; a module-level logger `log`; `start` (a time.time()
    # reading taken at the top of the step); and `world_size`.
    if dist.get_rank() % torch.cuda.device_count() == 0:
        data_idx = state.data_start_index + (state.iteration * state.total_batch_size)
        log.info(
            f"epoch: {state.epoch}, iteration: {state.iteration}, data_idx: {data_idx}"
        )

    state.data_start_index += world_size * state.total_batch_size
    state.iteration += 1
    state.model_state = state.dist_model.state_dict()

    end = time.time()
    # Each train_step processes one mini-batch. Wall-clock timing on the host
    # is only approximate because CUDA kernels execute asynchronously, so the
    # throughput reported here is for illustration purposes only.
    batch_per_sec = 1 / (end - start)
    return state, torchelastic.SimpleWorkerStats(batch_per_sec)
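
For context, here is a minimal sketch of what the returned stats object carries, assuming the 0.1-era torchelastic API: SimpleWorkerStats wraps a single progress-rate number (batches per second in the example above) that the elastic train loop reads back through the WorkerStats interface. The accessor name get_progress_rate and the commented-out train call follow the old examples and are assumptions here, not guaranteed API.

import torchelastic

# Illustrative values only: wrap a throughput measurement so the elastic
# train loop can track this worker's progress rate.
stats = torchelastic.SimpleWorkerStats(42.0)  # e.g. 42 batches/sec
print(stats.get_progress_rate())  # assumed accessor; expected to print 42.0

# In the example above, train_step() returns (state, stats) on every call;
# the 0.1-era train loop consumes them roughly like:
#   state = torchelastic.train(elastic_coordinator, train_step, state)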