Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def _run_sync(self):
    """Synchronously sample controller configurations in batches and score them.

    Runs ``num_trials // controller_batch_size + 1`` rounds. Every round but
    the last requests ``controller_batch_size`` configurations. The last round
    requests the remaining ``num_trials % controller_batch_size`` and is
    skipped when that remainder is zero, so ``num_trials`` configurations are
    requested in total. Sampling and reward handling run inside
    ``mx.autograd.record()``, presumably so that the controller's
    log-probabilities can be differentiated afterwards.
    """
    decay = self.ema_baseline_decay
    num_full_batches = self.num_trials // self.controller_batch_size
    for i in tqdm(range(num_full_batches + 1)):
        with mx.autograd.record():
            # Sample controller_batch_size configurations; the final round only
            # takes the remainder. The previous expression
            # ``num_trials % num_trials`` was always 0, which silently dropped
            # the partial batch, and it divided by zero when num_trials == 0.
            if i == num_full_batches:
                batch_size = self.num_trials % self.controller_batch_size
            else:
                batch_size = self.controller_batch_size
            if batch_size == 0:
                continue
            configs, log_probs, entropies = self.controller.sample(
                batch_size, with_details=True)
            # schedule the training tasks and gather the reward
            rewards = self.sync_schedule_tasks(configs)
            # subtract baseline; the first observed reward seeds it
            if self.baseline is None:
                self.baseline = rewards[0]
            avg_rewards = mx.nd.array([reward - self.baseline for reward in rewards],
                                      ctx=self.controller.context)
            # EMA baseline
            # NOTE(review): ``decay``, ``log_probs``, ``entropies`` and
            # ``avg_rewards`` are not used in the code visible here; the EMA
            # baseline update / policy-gradient step appears to be missing
            # from this excerpt — confirm against the full method.
# NOTE(review): the statements below belong to a function whose signature is
# not part of this excerpt, and indentation has been lost, so the exact nesting
# is not recoverable from this copy. They use names not defined in these lines
# (args, reporter, net, train_data, val_data, batch_size, trainer, batch_fn,
# ctx, num_classes), presumably from that enclosing scope — confirm.
# Validation metric chosen by name via args.metric.
metric = get_metric_instance(args.metric)
# One training pass over train_data via default_train_fn. Neither ``epoch``
# nor the enumerate index ``i`` is used here.
def train(epoch):
for i, batch in enumerate(train_data):
default_train_fn(net, batch, batch_size, args.loss, trainer, batch_fn, ctx)
# Block until all pending MXNet operations have completed.
mx.nd.waitall()
# Evaluate on val_data: reset ``metric``, score every batch with
# default_val_fn, report the metric value to ``reporter`` as
# ``classification_reward`` and return it. The enumerate index ``i`` is unused.
def test(epoch):
metric.reset()
for i, batch in enumerate(val_data):
default_val_fn(net, batch, batch_fn, metric, ctx)
_, reward = metric.get()
reporter(epoch=epoch, classification_reward=reward)
return reward
# Train for epochs 1..args.epochs. Unless args.final_fit is set, validate
# after each epoch and show the score in the progress-bar description.
tbar = tqdm(range(1, args.epochs + 1))
for epoch in tbar:
train(epoch)
if not args.final_fit:
reward = test(epoch)
tbar.set_description('[Epoch {}] Validation: {:.3f}'.format(epoch, reward))
# With final_fit, return the collected model parameters and the class count
# (presumably after the epoch loop — nesting is lost in this copy).
if args.final_fit:
return {'model_params': collect_params(net),
'num_classes': num_classes}
def run(self):
# Train the supernet for self.epochs epochs, sampling a new architecture from
# the controller for every training batch.
# NOTE(review): this copy of ``run`` appears to be cut off. In the visible code
# ``idx`` is set to 0 but never incremented, so both ``idx % ...`` checks
# compare 0 (the controller would be trained on every batch once past
# warm-up), and the configured ``graph`` is never displayed. There is also no
# per-epoch validation/save here — confirm against the full method.
self._prefetch_controller()
tq = tqdm(range(self.epochs))
for epoch in tq:
# for recordio data: rewind iterators that expose reset() before each epoch
if hasattr(self.train_data, 'reset'): self.train_data.reset()
tbar = tqdm(self.train_data)
idx = 0
for batch in tbar:
# sample network configuration (first entry returned by pre_sample())
config = self.controller.pre_sample()[0]
self.supernet.sample(**config)
self.train_fn(self.supernet, batch, **self.train_args)
# Wait for all pending MXNet operations before continuing.
mx.nd.waitall()
# Intended: after warm-up, update the controller every
# update_arch_frequency batches.
if epoch >= self.warmup_epochs and (idx % self.update_arch_frequency) == 0:
self.train_controller()
# Plot the supernet graph every plot_frequency batches; in_ipynb()
# presumably restricts this to notebook sessions.
if self.plot_frequency > 0 and idx % self.plot_frequency == 0 and in_ipynb():
graph = self.supernet.graph
graph.attr(rankdir='LR', size='8,3')
def evaluate(loader_dev, metric, segment):
    """Evaluate the model on validation dataset.

    Runs the module-level ``model`` over every batch of ``loader_dev`` on
    ``ctx``, updates ``metric`` with each batch's predictions, and every
    ``args.log_interval`` batches logs progress via ``log_eval``.

    Args:
        loader_dev: iterable yielding
            ``(input_ids, valid_length, segment_ids, label)`` batches.
        metric: metric object; reset before evaluation, updated per batch.
        segment: name of the evaluated split. Not used in this body.
    """
    metric.reset()
    step_loss = 0
    tbar = tqdm(loader_dev)
    for batch_id, seqs in enumerate(tbar):
        input_ids, valid_length, segment_ids, label = seqs
        input_ids = input_ids.as_in_context(ctx)
        valid_length = valid_length.as_in_context(ctx).astype('float32')
        label = label.as_in_context(ctx)
        # The RoBERTa path calls the model without segment ids.
        if use_roberta:
            out = model(input_ids, valid_length)
        else:
            out = model(input_ids, segment_ids.as_in_context(ctx), valid_length)
        ls = loss_function(out, label).mean()
        step_loss += ls.asscalar()
        metric.update([label], [out])
        if (batch_id + 1) % args.log_interval == 0:
            log_eval(batch_id, len(loader_dev), metric, step_loss, args.log_interval, tbar)
            # log_eval receives the interval length alongside step_loss, so
            # step_loss must cover only the current interval: reset it here
            # instead of letting it accumulate over the whole evaluation.
            step_loss = 0
def validation(self):
    """Evaluate the controller's inferred architecture on ``self.val_data``.

    Samples the supernet with ``self.controller.inference()``, runs
    ``self.eval_fn`` on every validation batch while accumulating accuracy in
    an ``mx.metric.Accuracy``, then stores the final accuracy in
    ``self.val_acc`` and appends it to ``self.training_history``.
    """
    if hasattr(self.val_data, 'reset'):
        self.val_data.reset()
    # data iter, avoid memory leak
    it = iter(self.val_data)
    if hasattr(it, 'reset_sample_times'):
        it.reset_sample_times()
    tbar = tqdm(it)
    # update network arc
    config = self.controller.inference()
    self.supernet.sample(**config)
    metric = mx.metric.Accuracy()
    for batch in tbar:
        self.eval_fn(self.supernet, batch, metric=metric, **self.val_args)
        tbar.set_description('Val Acc: {}'.format(metric.get()[1]))
    # Read the final accuracy after the loop. It equals the last in-loop value,
    # but stays defined when the validation data is empty (MXNet metrics report
    # NaN when nothing was accumulated) instead of raising UnboundLocalError.
    reward = metric.get()[1]
    self.val_acc = reward
    self.training_history.append(reward)
def run(self):
# Train the supernet for self.epochs epochs, sampling a new architecture from
# the controller for every training batch.
# NOTE(review): this copy of ``run`` appears to be cut off: ``idx`` is never
# incremented in the visible code (so both ``idx % ...`` checks compare 0 and
# the controller would be trained on every batch once past warm-up), and there
# is no per-epoch validation/save — confirm against the full method.
tq = tqdm(range(self.epochs))
for epoch in tq:
# for recordio data (note: this variant performs no reset)
tbar = tqdm(self.train_data)
idx = 0
for (data, label) in tbar:
# sample network configuration (first entry returned by pre_sample())
config = self.controller.pre_sample()[0]
self.supernet.sample(**config)
self.train_fn(self.supernet, data, label, idx, epoch, **self.train_args)
# Intended: after warm-up, update the controller every
# update_arch_frequency batches.
if epoch >= self.warmup_epochs and (idx % self.update_arch_frequency) == 0:
self.train_controller()
# Render the supernet graph into the progress bar every plot_frequency
# batches; in_ipynb() presumably restricts this to notebook sessions.
if self.plot_frequency > 0 and idx % self.plot_frequency == 0 and in_ipynb():
graph = self.supernet.graph
graph.attr(rankdir='LR', size='8,3')
tbar.set_svg(graph._repr_svg_())
# NOTE(review): if self.baseline is still None (e.g. before the first
# controller update), '{:.2f}'.format(None) raises TypeError — confirm how
# the baseline is initialised.
tbar.set_description('avg reward: {:.2f}'.format(self.baseline))
def validation(self):
    """Evaluate the controller's inferred architecture on ``self.val_data``.

    Samples the supernet with ``self.controller.inference()``, runs
    ``self.eval_fn`` on every ``(data, label)`` batch, which accumulates into
    an ``AverageMeter``, then stores the final average (as a Python scalar)
    in ``self.val_acc`` and appends it to ``self.training_history``.

    Raises:
        ValueError: if ``self.val_data`` yields no batches.
    """
    # data iter
    tbar = tqdm(self.val_data)
    # update network arc
    config = self.controller.inference()
    self.supernet.sample(**config)
    metric = AverageMeter()
    reward = None
    for (data, label) in tbar:
        # eval_fn's return value is not needed: accuracy accumulates in metric.
        self.eval_fn(self.supernet, data, label, metric=metric, **self.val_args)
        reward = metric.avg
        tbar.set_description('Acc: {}'.format(reward.item()))
    if reward is None:
        # Previously this surfaced as an UnboundLocalError on ``reward``.
        raise ValueError('validation data is empty')
    self.val_acc = reward.item()
    # Store the same Python scalar as val_acc rather than the raw meter value,
    # matching what the MXNet scheduler keeps in training_history.
    self.training_history.append(self.val_acc)
def run(self):
    """Train the supernet and, after warm-up, the architecture controller.

    For each of ``self.epochs`` epochs, every training batch runs on an
    architecture freshly sampled from the controller. Once ``epoch`` reaches
    ``self.warmup_epochs``, the controller is updated every
    ``self.update_arch_frequency`` batches. When ``self.plot_frequency > 0``
    and ``in_ipynb()`` is true, the supernet graph is rendered into the
    progress bar every ``plot_frequency`` batches. Each epoch ends with
    ``self.validation()`` and ``self.save()``.
    """
    tq = tqdm(range(self.epochs))
    for epoch in tq:
        tbar = tqdm(self.train_data)
        for idx, (data, label) in enumerate(tbar):
            # sample network configuration
            config = self.controller.pre_sample()[0]
            self.supernet.sample(**config)
            self.train_fn(self.supernet, data, label, idx, epoch, **self.train_args)
            if epoch >= self.warmup_epochs and idx % self.update_arch_frequency == 0:
                self.train_controller()
            if self.plot_frequency > 0 and idx % self.plot_frequency == 0 and in_ipynb():
                graph = self.supernet.graph
                graph.attr(rankdir='LR', size='8,3')
                tbar.set_svg(graph._repr_svg_())
            # self.baseline may still be None before the controller has been
            # trained (e.g. during warm-up) — assumption, confirm against
            # __init__/train_controller. '{:.2f}'.format(None) raises TypeError,
            # so only show the reward once a baseline exists.
            if self.baseline is not None:
                tbar.set_description('avg reward: {:.2f}'.format(self.baseline))
        self.validation()
        self.save()