How to use the torchelastic.coordinator.StopException exception class in torchelastic

To help you get started, we’ve selected a few torchelastic examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: pytorch/elastic — torchelastic/train_loop.py (view on GitHub)
store, rank, world_size = elastic_coordinator.rendezvous_barrier()
            elastic_coordinator.init_process_group()

            # load checkpoint if necessary
            state = checkpoint_util.load_checkpoint(state, rank)

            state_sync_start_time = time.time()
            state.sync(world_size, rank)
            publish_metric(
                "torchelastic",
                "state_sync.duration.ms",
                get_elapsed_time_ms(state_sync_start_time),
            )
            checkpoint_util.set_checkpoint_loaded()
            log.info("Rank {0} synced state with other nodes".format(rank))
        except StopException:
            log.info("Rank {0} received stopped signal. Exiting training.".format(rank))
            break
        except RuntimeError as e:
            # See: https://github.com/pytorch/elastic/issues/7
            elastic_coordinator.on_error(e)
            state.apply_snapshot(snapshot)
            failure_count += 1
            continue
        except (NonRetryableException, Exception) as e:
            elastic_coordinator.on_error(e)
            raise
        finally:
            publish_metric(
                "torch_elastic",
                "outer_train_loop.duration.ms",
                get_elapsed_time_ms(start_time),
GitHub: pytorch/elastic — torchelastic/p2p/coordinator_p2p.py (view on GitHub)
def rendezvous_barrier(self):
        self._destroy_process_group()
        try:
            self.store, self.rank, self.world_size = self.rendezvous.next_rendezvous()
        except RendezvousClosedException:
            # Sets the local variable to True
            self.stop_training = True
            raise StopException(
                "Rank {0} received RendezvousClosedException."
                " Raising a StopException".format(self.rank)
            )
        except (RuntimeError, Exception) as e:
            raise NonRetryableException(
                "Rank {0} received an Exception."
                " Detailed message: {1}".format(self.rank, str(e))
            )
        log.info(
            "Got next rendezvous: rank {0}, world size {1}".format(
                self.rank, self.world_size
            )
        )

        # Assume straggler state is unreliable after rendezvous
        self.is_worker_straggler = False