def adjust_learning_rate(optimizer, epoch, world_size, params):
    # Step-decay schedule, scaled linearly with the number of workers:
    # the rate drops by 10x every 30 epochs and the decay is capped at
    # 0.1**3 from epoch 80 onward.
    # NOTE: the snippet begins at an `elif`; the function name and the
    # first branch below are reconstructions that extend the same decay
    # formula to the early epochs.
    if epoch < 30:
        lr = world_size * params.base_learning_rate
    elif epoch < 80:
        lr = world_size * params.base_learning_rate * (0.1 ** (epoch // 30))
    else:
        lr = world_size * params.base_learning_rate * (0.1 ** 3)

    for param_group in optimizer.param_groups:
        lr_old = param_group["lr"]
        param_group["lr"] = lr
        # Trick: apply momentum correction when lr is updated
        if lr > lr_old:
            param_group["momentum"] = lr / lr_old * 0.9  # momentum correction
        else:
            param_group["momentum"] = 0.9  # default momentum
    return
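
# A quick sanity check of the schedule above (a hypothetical usage sketch:
# the `_Params` stand-in and the chosen epochs are illustrative, not from
# the original code).
import torch

_model = torch.nn.Linear(10, 10)
_optimizer = torch.optim.SGD(_model.parameters(), lr=0.1, momentum=0.9)

class _Params:
    base_learning_rate = 0.1

for _epoch in (0, 30, 60, 80, 100):
    adjust_learning_rate(_optimizer, _epoch, world_size=4, params=_Params())
    print(_epoch, _optimizer.param_groups[0]["lr"])
# Prints 0.4, 0.04, 0.004, then 0.0004 twice: the decay caps at 0.1**3
# from epoch 80 onward.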
class ImagenetState(torchelastic.State):
    """
    Client-provided State object; it is serializable and captures the entire
    state needed for executing one iteration of training.
    """

    def __init__(self, model, params, dataset, num_epochs, epoch=0):
        self.model = model
        self.params = params
        self.dataset = dataset
        self.total_batch_size = params.batch_per_device
        self.num_epochs = num_epochs
        self.epoch = epoch
        self.iteration = 0
        self.data_start_index = 0
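
# Constructing the state is then one call per trainer. A hypothetical wiring
# sketch: the placeholder model, params, and dataset below are illustrative,
# and the commented-out loop call mirrors the elastic_train_loop.train(...)
# shape used by the tests later in this page.
import torch

class _ImagenetParams:  # placeholder hyper-parameters; only the field
    batch_per_device = 32  # read by ImagenetState.__init__ is defined

state = ImagenetState(
    model=torch.nn.Linear(10, 10),  # placeholder model
    params=_ImagenetParams(),
    dataset=list(range(100)),       # placeholder dataset
    num_epochs=90,
)
# state = elastic_train_loop.train(elastic_coordinator, elastic_train_step, state)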
    def _train_rerendezvous(self, _, run_id, train_step, hooks, state_override=None):
        """
        Alternate sub-process trainer entry point used by tests that want to
        force a re-rendezvous after every iteration.
        """

        class RerendezvousCoordinatorP2P(CoordinatorP2P):
            def should_rendezvous(self, state):
                return True

        elastic_coordinator = RerendezvousCoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        state = self._train_common(
            _, elastic_coordinator, train_step, hooks, state_override
        )
        return state
    def test_normal_flow_with_worker_stats(self):
        """
        Test a very simple 4-trainer case, where elastic_train_step
        also returns a non-None WorkerStats instance.
        """
        run_id = self._generate_run_id()

        nprocs = 4
        qouts = []
        qerrs = []
        prog_rates = [100, 95, 42, None]

        CoordinatorP2P.MONITOR_PROGRESS_FREQ = 1
        original_monitor_progress = CoordinatorP2P.monitor_progress

        def patched_monitor_progress(self, state, worker_stats):
            original_monitor_progress(self, state, worker_stats)
            # Save into state for retrieval in `_get_or_raise` below.
            if hasattr(self, "last_relative_prog_rate"):
                state._test_relative_prog_rate = self.last_relative_prog_rate
            if hasattr(self, "is_worker_straggler"):
                state._test_is_straggler = self.is_worker_straggler

        with patch.object(CoordinatorP2P, "monitor_progress", patched_monitor_progress):
            for i in range(0, nprocs):
                _, qout, qerr = self._spawn(
                    self._train_with_worker_stats,
                    run_id,
                    # The snippet is truncated here; the remaining arguments
                    # are reconstructed from the `_train_with_worker_stats`
                    # signature below.
                    _train_step,
                    None,  # hooks
                    None,  # state_override
                    prog_rates[i],  # worker_stats_progress_rate
                )
                qouts.append(qout)
                qerrs.append(qerr)
    def _train(self, _, run_id, train_step, hooks, state_override=None):
        """
        Common sub-process trainer entry point used by most tests.
        """
        elastic_coordinator = CoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        return self._train_common(
            _, elastic_coordinator, train_step, hooks, state_override
        )
    def _train_with_worker_stats(
        self,
        _,
        run_id,
        train_step,
        hooks,
        state_override=None,
        worker_stats_progress_rate=None,
    ):
        """
        Similar to `_train`, but uses a coordinator that validates the
        WorkerStats object.
        """
        fixed_worker_stats = TestWorkerStats(progress_rate=worker_stats_progress_rate)

        elastic_coordinator = CoordinatorP2P(
            c10d_backend="gloo",
            init_method=self.get_rdzv_url(run_id, self.min_size, self.max_size),
            max_num_trainers=self.max_size,
            process_group_timeout=10000,
        )
        return self._train_common(
            _,
            elastic_coordinator,
            train_step,
            hooks,
            state_override,
            fixed_worker_stats,
        )
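
# `TestWorkerStats` itself is not shown in these snippets. A minimal sketch
# of what it plausibly looks like: a fixed-rate stats object handed back on
# every train step. The real class likely implements torchelastic's
# WorkerStats interface; the getter name below is an assumption.
class TestWorkerStats:
    def __init__(self, progress_rate):
        self.progress_rate = progress_rate

    def get_progress_rate(self):
        # None signals "no progress information available".
        return self.progress_rate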
    def _train_common(
        self,
        _,
        elastic_coordinator,
        train_step,
        hooks,
        state_override=None,
        worker_stats=None,
    ):
        state = TestState() if state_override is None else state_override

        elastic_train_step = _make_elastic_train_step(train_step, hooks, worker_stats)
        state = elastic_train_loop.train(elastic_coordinator, elastic_train_step, state)
        return state
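
# The `train_step` wrapped above is supplied per test and is not shown in
# these snippets. A hypothetical sketch of the shape it takes, based on the
# sample-counting comments in the exception test below (`samples` and
# `next_index` are assumed field names; `total_sum` appears in that test):
def _example_train_step(state):
    sample = state.samples[state.next_index]  # next unprocessed sample
    state.next_index += 1
    state.total_sum += sample                 # running sum asserted on later
    return state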
    def test_process_retryable_exception(self):
        # NOTE: the snippet begins inside the test body; this header and the
        # run_id setup line are reconstructed from the other tests above.
        run_id = self._generate_run_id()

        def process_retryable_exception():
            # Raise exception repeatedly
            raise RuntimeError("train_step throws RuntimeError (retryable exception)")

        hooks = {"process_retryable_exception": process_retryable_exception}

        nprocs = 4
        qouts = []
        qerrs = []
        for _ in range(0, nprocs - 2):
            _, qout, qerr = self._spawn(self._train, run_id, _train_step, None)
            qouts.append(qout)
            qerrs.append(qerr)
        with patch.object(elastic_train_loop, "MAX_FAILURES", 5):
            for _ in range(nprocs - 2, nprocs):
                _, qout, qerr = self._spawn(self._train, run_id, _train_step, hooks)
                qouts.append(qout)
                qerrs.append(qerr)

        # Gather all "trained" values from all trainers, and ensure
        # that the bad trainers raise the expected exception.
        sums = []
        for i in range(0, nprocs):
            if i <= 1:
                state = _get_or_raise(qouts[i], qerrs[i])
                sums.append(state.total_sum)
        # Initially, the 4 trainers consume 2 samples each; the surviving
        # 2 trainers then divide the remaining 20 - 8 = 12 samples, so each
        # survivor successfully processes 2 + 6 = 8 samples. `nums` keeps
        # track of the samples "seen", so each surviving trainer should
        # have seen 8 of them.
    def join_rendezvous(self, expected_version):
        # Use compare-and-swap to add self to rendezvous state:
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()

            if state["status"] != "joinable":
                raise EtcdRendezvousRetryableFailure(
                    "Rendezvous state became non-joinable before we could join. "
                    "Must join next one."
                )

            if state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous version changed. Must try join the new one."
                )

            assert (
                len(state["participants"]) < self._num_max_workers
            ), "Logic error: joinable rendezvous should always have space left"

            this_rank = len(state["participants"])
            state["participants"].append(this_rank)

            # When reaching min workers, or changing state to frozen, we'll set
            # the active_version node to be ephemeral.
            if len(state["participants"]) == self._num_max_workers:
                state["status"] = "frozen"
                state["keep_alives"] = []
                set_ttl = CONST_ETCD_FROZEN_TTL
    def announce_self_waiting(self, expected_version):
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()

            if state["status"] != "final" or state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately()

            # Increment counter to signal an additional waiting worker.
            state["num_workers_waiting"] += 1

            try:
                active_version = self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=json.dumps(state),
                    prev_value=active_version.value,
                )
                return active_version
            except etcd.EtcdCompareFailed:
                log.info("Announce self as waiting CAS unsuccessful, retrying")
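
# The rendezvous methods on this page all follow the same optimistic-
# concurrency pattern: read the shared state, mutate a local copy, and
# commit it with etcd's test-and-set, retrying on a compare failure. A
# stripped-down sketch of that loop (`mutate` is a placeholder callback;
# `test_and_set` and `EtcdCompareFailed` are the python-etcd APIs used
# above):
import json

import etcd


def cas_update(client, key, mutate):
    while True:
        node = client.get(key)         # read current value and etcd index
        state = json.loads(node.value)
        mutate(state)                  # apply the local modification
        try:
            return client.test_and_set(
                key=key,
                value=json.dumps(state),
                prev_value=node.value,  # commit only if nobody else wrote
            )
        except etcd.EtcdCompareFailed:
            continue                    # lost the race; re-read and retry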
    def confirm_membership(self, expected_version, this_rank):
        # Compare-and-swap loop
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()

            if state["status"] != "frozen":
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous no longer frozen, before we confirmed. "
                    "Must join next one"
                )
            if state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous version changed. Must try join the new one."
                )

            this_lease_key = self.get_path(
                "/rdzv/v_{}/rank_{}".format(expected_version, this_rank)
            )
            self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL)

            state["keep_alives"].append(this_lease_key)
            if len(state["keep_alives"]) == len(state["participants"]):
                # Everyone confirmed (this rank is last to do so)