How to use the torchelastic.distributed module in torchelastic

To help you get started, we’ve selected a few torchelastic examples, based on popular ways it is used in public projects.

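The examples center on torchelastic.distributed.all_gather_return_max_long, an all-gather of a single long value that lets every worker agree on the maximum and on which rank holds it. Below is a minimal sketch of the setup the snippets assume; the gloo/env:// initialization is illustrative, the alias edist simply matches how the snippets refer to the module, and the (max_rank, max_value) return shape is inferred from the first example rather than from torchelastic's documentation.

import torch.distributed as dist
import torchelastic.distributed as edist

# the snippets assume an already-initialized process group; env:// rendezvous
# (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) is one common way to set it up
dist.init_process_group(backend="gloo", init_method="env://")

# every rank contributes one long value; the first example below unpacks the
# result as (rank holding the maximum, maximum value), an inferred shape
local_progress = 0  # e.g. a data start index, or a checkpoint-loaded flag
max_rank, _ = edist.all_gather_return_max_long(local_progress)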

pytorch / elastic / examples / imagenet / main.py (view on GitHub)
def _sync_state(self, rank):
        # broadcast from the max rank with the biggest start index
        max_rank, _ = edist.all_gather_return_max_long(self.data_start_index)

        # Broadcast the state from max_rank
        buffer = io.BytesIO()
        self.save(buffer)
        state_tensor = torch.ByteTensor(list(buffer.getvalue()))
        state_size = torch.LongTensor([state_tensor.size()])
        dist.broadcast(state_size, src=max_rank)

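        # ranks other than the source allocate a zeroed receive buffer of the broadcast size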
        if rank != max_rank:
            state_tensor = torch.ByteTensor([0 for _ in range(state_size[0])])

        dist.broadcast(state_tensor, src=max_rank)

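        # deserialize the broadcast bytes so every rank ends up with identical state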
        buffer = io.BytesIO(state_tensor.numpy().tobytes())
        self.load(buffer)
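
The save() and load() calls above belong to the example's training-state object, which this page does not show. A hypothetical minimal stand-in, just to make that contract concrete (the real class in the imagenet example carries considerably more, such as model and optimizer state):

import torch

class SketchState:
    """Hypothetical stand-in for the state object that _sync_state lives on."""

    def __init__(self):
        self.epoch = 0
        self.data_start_index = 0

    def save(self, stream):
        # serialize everything that must be identical across workers
        torch.save(
            {"epoch": self.epoch, "data_start_index": self.data_start_index},
            stream,
        )

    def load(self, stream):
        snapshot = torch.load(stream)
        self.epoch = snapshot["epoch"]
        self.data_start_index = snapshot["data_start_index"]
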
pytorch / elastic / torchelastic / checkpoint / api.py (view on GitHub)
def load_checkpoint(self, state, rank: int):
        """
        Loads checkpoint if the checkpoint manager has been configured and
        at least one worker has already loaded the checkpoint
        """
        if not self.checkpoint_manager:
            # checkpoint not enabled
            return state

        # all-gather `checkpoint_loaded` from all trainers; true if any
        # trainer has ever loaded a checkpoint
        any_checkpoint_loaded = (
            edist.all_gather_return_max_long(1 if self.checkpoint_loaded else 0) == 1
        )

        if any_checkpoint_loaded:
            # checkpoint already loaded by one of the existing trainers
            return state

        # we load the checkpoint only if all trainers start from scratch; it
        # is not necessary when a healthy trainer already exists, since new
        # trainers can sync state from it.
        # Start with the simple scenario: one single trainer loads the
        # checkpoint and the other trainers sync from it
        if rank == 0:
            state = self._do_load_checkpoint(state)

        return state
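
Read together, the two snippets sketch a recovery path: workers first agree, via all_gather_return_max_long, on whether anyone already holds usable state; if nobody does, a single trainer (rank 0) reads the checkpoint, and a broadcast-based sync like _sync_state in the first example then copies that state to every other rank. A sketch of that wiring, with hypothetical names (checkpoint_util here stands for whatever object exposes load_checkpoint):

def recover(state, checkpoint_util, rank):
    # rank 0 reads from durable storage only when no trainer has state yet
    state = checkpoint_util.load_checkpoint(state, rank)
    # the broadcast-based sync from the first example then fans the state out
    state._sync_state(rank)
    return state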