# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _training_schedule(config, params):
  """Fill in loop timing, checkpointing, and data-collection settings.

  Reads overrides from `params` (falling back to defaults) and mutates
  `config` in place with step counts, logging/checkpoint cadence, episode
  directories, and the random/train/test collection definitions.
  """
  train_steps = int(params.get('train_steps', 50000))
  test_steps = int(params.get('test_steps', 50))
  config.train_steps = train_steps
  config.test_steps = test_steps
  config.max_steps = int(params.get('max_steps', 5e7))
  # Log once per training phase; checkpoint on the test cadence only.
  config.train_log_every = train_steps
  config.train_checkpoint_every = None
  config.test_checkpoint_every = int(
      params.get('checkpoint_every', 10 * test_steps))
  config.checkpoint_to_load = None
  # Skip temporary variables when saving checkpoints.
  config.savers = [tools.AttrDict(exclude=(r'.*_temporary.*',))]
  config.print_metrics_every = train_steps // 10
  config.train_dir = os.path.join(params.logdir, 'train_episodes')
  config.test_dir = os.path.join(params.logdir, 'test_episodes')
  # Seed the replay directories with random-policy episodes first.
  config.random_collects = _initial_collection(config, params)
  train_defaults = dict(
      prefix='train',
      save_episode_dir=config.train_dir,
      action_noise=params.get('train_action_noise', 0.3))
  config.train_collects = _active_collection(
      params.get('train_collects', [{}]), train_defaults, config, params)
  test_defaults = dict(
      prefix='test',
      save_episode_dir=config.test_dir,
      action_noise=0.0)
  config.test_collects = _active_collection(
      params.get('test_collects', [{}]), test_defaults, config, params)
def test_dummy_isolate_none(self):
  """Smoke-test a short training run on the dummy task, no env isolation."""
  run_params = tools.AttrDict(
      task='dummy',
      isolate_envs='none',
      max_steps=30)
  args = tools.AttrDict(
      logdir=self.get_temp_dir(),
      num_runs=1,
      config='debug',
      params=run_params,
      ping_every=0,
      resume_runs=False)
  # tf.app.run calls sys.exit on completion; swallow that so the test passes.
  try:
    tf.app.run(lambda _: train.main(args), [sys.argv[0]])
  except SystemExit:
    pass
def test_dm_control_isolate_none(self):
  """Smoke-test a short training run on a dm_control task, no env isolation."""
  run_params = tools.AttrDict(
      task='cup_catch',
      isolate_envs='none',
      max_steps=30)
  args = tools.AttrDict(
      logdir=self.get_temp_dir(),
      num_runs=1,
      config='debug',
      params=run_params,
      ping_every=0,
      resume_runs=False)
  # tf.app.run calls sys.exit on completion; swallow that so the test passes.
  try:
    tf.app.run(lambda _: train.main(args), [sys.argv[0]])
  except SystemExit:
    pass
def test_dummy_isolate_thread(self):
  """Smoke-test a short training run on the dummy task with thread isolation."""
  run_params = tools.AttrDict(
      task='dummy',
      isolate_envs='thread',
      max_steps=30)
  args = tools.AttrDict(
      logdir=self.get_temp_dir(),
      num_runs=1,
      config='debug',
      params=run_params,
      ping_every=0,
      resume_runs=False)
  # tf.app.run calls sys.exit on completion; swallow that so the test passes.
  try:
    tf.app.run(lambda _: train.main(args), [sys.argv[0]])
  except SystemExit:
    pass
# NOTE(review): fragment of an active-collection builder; the enclosing
# definition and the opening of the defaults AttrDict start above this
# excerpt, and the loop body is truncated below it.
name='main',
batch_size=1,
horizon=params.get('planner_horizon', 12),
objective=params.get('collect_objective', 'reward'),
# Planner collection starts after `collect_every` steps and repeats on
# the same cadence thereafter.
after=params.get('collect_every', 5000),
every=params.get('collect_every', 5000),
until=-1,  # presumably -1 means "never stop" -- confirm against scheduler
action_noise=0.0,
action_noise_ramp=params.get('action_noise_ramp', 0),
action_noise_min=params.get('action_noise_min', 0.0),
)
# User-supplied defaults override the generic defaults assembled above.
defs.update(defaults)
sims = tools.AttrDict(_unlocked=True)
# One simulation per (task, collect spec) pair.
for task in config.tasks:
for collect in collects:
# Each collect spec falls back to the merged defaults for missing keys.
collect = tools.AttrDict(collect, _defaults=defs)
sim = _define_simulation(
task, config, params, collect.horizon, collect.batch_size,
collect.objective)
sim.unlock()
sim.save_episode_dir = collect.save_episode_dir
sim.steps_after = int(collect.after)
sim.steps_every = int(collect.every)
sim.steps_until = int(collect.until)
# Exploration noise scale optionally ramps linearly down to a minimum.
sim.exploration = tools.AttrDict(
scale=collect.action_noise,
schedule=tools.bind(
tools.schedule.linear,
ramp=collect.action_noise_ramp,
min=collect.action_noise_min,
))
name = '{}_{}_{}'.format(collect.prefix, collect.name, task.name)
# NOTE(review): fragment of the CLI entry point; `parser`, `boolean`, and
# the earlier arguments (e.g. --logdir) are defined above this excerpt.
parser.add_argument(
'--num_runs', type=int, default=1)
parser.add_argument(
'--config', default='default',
help='Select a configuration function from scripts/configs.py.')
parser.add_argument(
'--params', default='{}',
help='YAML formatted dictionary to be used by the config.')
parser.add_argument(
'--ping_every', type=int, default=0,
help='Used to prevent conflicts between multiple workers; 0 to disable.')
parser.add_argument(
'--resume_runs', type=boolean, default=True,
help='Whether to resume unfinished runs in the log directory.')
args_, remaining = parser.parse_known_args()
# '#' is accepted as a shell-friendly stand-in for ',' in the YAML params
# string; safe_load avoids executing arbitrary YAML tags.
args_.params = tools.AttrDict(yaml.safe_load(args_.params.replace('#', ',')))
args_.logdir = args_.logdir and os.path.expanduser(args_.logdir)
# Forward the program name plus unparsed flags to tf.app.run.
remaining.insert(0, sys.argv[0])
tf.app.run(lambda _: main(args_), remaining)
def define_model(data, trainer, config):
  """Build the model graph: recurrent cell, encoder, and prediction heads.

  Args:
    data: Dict of batched sequence tensors keyed by modality name
      (e.g. 'action'); shapes are [batch, steps, ...] -- assumed from the
      `shape[2:]` slicing below, confirm against the data pipeline.
    trainer: Object exposing `step`, `global_step`, and `phase` tensors.
    config: AttrDict providing `cell`, `encoder`, and `heads` constructors.
  """
  tf.logging.info('Build TensorFlow compute graph.')
  dependencies = []
  cleanups = []
  step = trainer.step
  global_step = trainer.global_step
  phase = trainer.phase
  # Instantiate network blocks.
  # Templates share variables across calls; create_scope_now_ pins the
  # variable scope at definition time instead of at first call.
  cell = config.cell()
  kwargs = dict(create_scope_now_=True)
  encoder = tf.make_template('encoder', config.encoder, **kwargs)
  heads = tools.AttrDict(_unlocked=True)
  # Features from the zero state are used only to trigger weight creation.
  dummy_features = cell.features_from_state(cell.zero_state(1, tf.float32))
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    kwargs = dict(create_scope_now_=True)
    if key in data:
      kwargs['data_shape'] = data[key].shape[2:].as_list()
    elif key == 'action_target':
      # No 'action_target' tensor in the data; mirror the action shape.
      kwargs['data_shape'] = data['action'].shape[2:].as_list()
    heads[key] = tf.make_template(name, head, **kwargs)
    heads[key](dummy_features)  # Initialize weights.
  # Apply and optimize model.
  embedded = encoder(data)
  with tf.control_dependencies(dependencies):
    embedded = tf.identity(embedded)
  # Expose all locals to downstream code as a single attribute dict.
  # NOTE(review): the function appears truncated here (no return visible);
  # graph construction likely continues past this excerpt.
  graph = tools.AttrDict(locals())
def _initial_collection(config, params):
  """Define random-policy seed collections for the train and test dirs.

  Returns an AttrDict mapping '<prefix><task name>' to a simulation spec
  that records `num_seed_episodes` episodes into the matching directory.
  """
  num_seed_episodes = params.get('num_seed_episodes', 5)
  sims = tools.AttrDict(_unlocked=True)
  targets = (
      ('train-', config.train_dir),
      ('test-', config.test_dir))
  for task in config.tasks:
    for prefix, episode_dir in targets:
      sims[prefix + task.name] = tools.AttrDict(
          task=task,
          save_episode_dir=episode_dir,
          num_episodes=num_seed_episodes)
  return sims