How to use planet.tools.AttrDict in planet

To help you get started, we’ve selected a few planet examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: google-research/planet, planet/scripts/configs.py (view on GitHub)
def _training_schedule(config, params):
  """Populate `config` with the train/test schedule and collection settings.

  Each value is taken from `params` when present and otherwise falls back to
  the default written below. `config` is modified in place; as shown, the
  function returns None.

  Args:
    config: Mutable configuration object. It must already provide `tasks`,
      which `_initial_collection` iterates over.
    params: Mapping of user overrides that also exposes `logdir` as an
      attribute; the episode directory paths are joined onto it.
  """
  # Step budgets: one train phase, one test phase, and the whole run.
  config.train_steps = int(params.get('train_steps', 50000))
  config.test_steps = int(params.get('test_steps', 50))
  config.max_steps = int(params.get('max_steps', 5e7))

  # Logging / checkpointing cadence. Checkpoints are only scheduled on the
  # test side; by default every ten test phases' worth of steps.
  config.train_log_every = config.train_steps
  config.train_checkpoint_every = None
  fallback_checkpoint_every = 10 * config.test_steps
  config.test_checkpoint_every = int(
      params.get('checkpoint_every', fallback_checkpoint_every))
  config.checkpoint_to_load = None
  # Saver spec with an exclude pattern; presumably variables whose names
  # match it are not saved (TODO confirm against the saver code).
  temporary_pattern = r'.*_temporary.*'
  config.savers = [tools.AttrDict(exclude=(temporary_pattern,))]
  config.print_metrics_every = config.train_steps // 10

  # Episode storage locations under the run's log directory.
  logdir = params.logdir
  config.train_dir = os.path.join(logdir, 'train_episodes')
  config.test_dir = os.path.join(logdir, 'test_episodes')

  # Initial collections first, then the active collections for each phase.
  config.random_collects = _initial_collection(config, params)
  train_specs = params.get('train_collects', [{}])
  train_defaults = dict(
      prefix='train',
      save_episode_dir=config.train_dir,
      action_noise=params.get('train_action_noise', 0.3),
  )
  config.train_collects = _active_collection(
      train_specs, train_defaults, config, params)
  test_specs = params.get('test_collects', [{}])
  test_defaults = dict(
      prefix='test',
      save_episode_dir=config.test_dir,
      action_noise=0.0,
  )
  config.test_collects = _active_collection(
      test_specs, test_defaults, config, params)
Source: google-research/planet, planet/scripts/test_planet.py (view on GitHub)
def test_dummy_isolate_none(self):
    """Smoke-test a short training run on the dummy task, no env isolation.

    Uses the 'debug' config with `isolate_envs='none'` and `max_steps=30`,
    writing logs to a per-test temporary directory.
    """
    args = tools.AttrDict(
        logdir=self.get_temp_dir(),
        num_runs=1,
        config='debug',
        params=tools.AttrDict(
            task='dummy',
            isolate_envs='none',
            max_steps=30),
        ping_every=0,
        resume_runs=False)
    try:
      tf.app.run(lambda _: train.main(args), [sys.argv[0]])
    except SystemExit as exc:
      # tf.app.run leaves via SystemExit even after a successful main, so the
      # exit itself is expected. Only a clean exit (code None or 0) counts as
      # success; re-raise anything else so failures are not silently hidden.
      if exc.code not in (None, 0):
        raise
Source: google-research/planet, planet/scripts/test_planet.py (view on GitHub)
def test_dm_control_isolate_none(self):
    """Smoke-test a short training run on `cup_catch`, no env isolation.

    Uses the 'debug' config with `isolate_envs='none'` and `max_steps=30`,
    writing logs to a per-test temporary directory.
    """
    args = tools.AttrDict(
        logdir=self.get_temp_dir(),
        num_runs=1,
        config='debug',
        params=tools.AttrDict(
            task='cup_catch',
            isolate_envs='none',
            max_steps=30),
        ping_every=0,
        resume_runs=False)
    try:
      tf.app.run(lambda _: train.main(args), [sys.argv[0]])
    except SystemExit as exc:
      # tf.app.run leaves via SystemExit even after a successful main, so the
      # exit itself is expected. Only a clean exit (code None or 0) counts as
      # success; re-raise anything else so failures are not silently hidden.
      if exc.code not in (None, 0):
        raise
Source: google-research/planet, planet/scripts/test_planet.py (view on GitHub)
def test_dummy_isolate_thread(self):
    """Smoke-test a short training run on the dummy task, thread isolation.

    Uses the 'debug' config with `isolate_envs='thread'` and `max_steps=30`,
    writing logs to a per-test temporary directory.
    """
    args = tools.AttrDict(
        logdir=self.get_temp_dir(),
        num_runs=1,
        config='debug',
        params=tools.AttrDict(
            task='dummy',
            isolate_envs='thread',
            max_steps=30),
        ping_every=0,
        resume_runs=False)
    try:
      tf.app.run(lambda _: train.main(args), [sys.argv[0]])
    except SystemExit as exc:
      # tf.app.run leaves via SystemExit even after a successful main, so the
      # exit itself is expected. Only a clean exit (code None or 0) counts as
      # success; re-raise anything else so failures are not silently hidden.
      if exc.code not in (None, 0):
        raise
Source: google-research/planet, planet/scripts/configs.py (view on GitHub; the excerpt below starts mid-function)
name='main',
      batch_size=1,
      horizon=params.get('planner_horizon', 12),
      objective=params.get('collect_objective', 'reward'),
      after=params.get('collect_every', 5000),
      every=params.get('collect_every', 5000),
      until=-1,
      action_noise=0.0,
      action_noise_ramp=params.get('action_noise_ramp', 0),
      action_noise_min=params.get('action_noise_min', 0.0),
  )
  defs.update(defaults)
  sims = tools.AttrDict(_unlocked=True)
  for task in config.tasks:
    for collect in collects:
      collect = tools.AttrDict(collect, _defaults=defs)
      sim = _define_simulation(
          task, config, params, collect.horizon, collect.batch_size,
          collect.objective)
      sim.unlock()
      sim.save_episode_dir = collect.save_episode_dir
      sim.steps_after = int(collect.after)
      sim.steps_every = int(collect.every)
      sim.steps_until = int(collect.until)
      sim.exploration = tools.AttrDict(
          scale=collect.action_noise,
          schedule=tools.bind(
              tools.schedule.linear,
              ramp=collect.action_noise_ramp,
              min=collect.action_noise_min,
          ))
      name = '{}_{}_{}'.format(collect.prefix, collect.name, task.name)
Source: google-research/planet, planet/scripts/train.py (view on GitHub; the excerpt below starts partway through the argument-parsing code)
parser.add_argument(
      '--num_runs', type=int, default=1)
  parser.add_argument(
      '--config', default='default',
      help='Select a configuration function from scripts/configs.py.')
  parser.add_argument(
      '--params', default='{}',
      help='YAML formatted dictionary to be used by the config.')
  parser.add_argument(
      '--ping_every', type=int, default=0,
      help='Used to prevent conflicts between multiple workers; 0 to disable.')
  parser.add_argument(
      '--resume_runs', type=boolean, default=True,
      help='Whether to resume unfinished runs in the log directory.')
  args_, remaining = parser.parse_known_args()
  args_.params = tools.AttrDict(yaml.safe_load(args_.params.replace('#', ',')))
  args_.logdir = args_.logdir and os.path.expanduser(args_.logdir)
  remaining.insert(0, sys.argv[0])
  tf.app.run(lambda _: main(args_), remaining)
Source: google-research/planet, planet/training/define_model.py (view on GitHub; excerpt)
def define_model(data, trainer, config):
  """Build the model's TensorFlow graph (visible portion of the function).

  Instantiates the cell, an encoder template, and one template per entry in
  `config.heads`. Each head is called once on dummy features so its
  variables exist, and then `data` is encoded.

  Args:
    data: Mapping of named tensors, passed whole to the encoder. For a head
      whose key is in `data`, that tensor's dims from index 2 onward become
      the head's `data_shape`. The head 'action_target' falls back to
      `data['action']` when it has no entry of its own. (Presumably dims 0-1
      are batch and time; TODO confirm.)
    trainer: Object providing `step`, `global_step` and `phase`.
    config: Configuration providing `cell` (called with no arguments),
      `encoder` and `heads` (a mapping of key -> head function).

  Note:
    `graph` is built from `locals()`, so every local name in this function,
    including loop variables such as `key`, `head`, `name` and the last
    `kwargs`, becomes a field of `graph`. Renaming, adding or removing locals
    changes its contents.

  NOTE(review): the visible snippet ends right after building `graph`,
  with no return statement, so as shown the function returns None. This is
  presumably a truncated excerpt; confirm against the full source.
  """
  tf.logging.info('Build TensorFlow compute graph.')
  # Ops that `embedded` is made to depend on below; still empty at that point
  # in the visible code.
  dependencies = []
  # Not used in the visible code; presumably filled further on (TODO confirm).
  cleanups = []
  step = trainer.step
  global_step = trainer.global_step
  phase = trainer.phase

  # Instantiate network blocks.
  cell = config.cell()
  # create_scope_now_=True makes tf.make_template create the variable scope
  # immediately rather than on the template's first call.
  kwargs = dict(create_scope_now_=True)
  encoder = tf.make_template('encoder', config.encoder, **kwargs)
  # Created with `_unlocked=True`, presumably so new keys can be assigned in
  # the loop below.
  heads = tools.AttrDict(_unlocked=True)
  # Features derived from a batch-size-1 zero state; used only to call each
  # head once so its weights are created.
  dummy_features = cell.features_from_state(cell.zero_state(1, tf.float32))
  for key, head in config.heads.items():
    name = 'head_{}'.format(key)
    # Fresh dict per head so a previous head's data_shape does not carry over.
    kwargs = dict(create_scope_now_=True)
    if key in data:
      kwargs['data_shape'] = data[key].shape[2:].as_list()
    elif key == 'action_target':
      # No 'action_target' tensor in `data`: size the head like 'action'.
      kwargs['data_shape'] = data['action'].shape[2:].as_list()
    heads[key] = tf.make_template(name, head, **kwargs)
    heads[key](dummy_features)  # Initialize weights.

  # Apply and optimize model.
  embedded = encoder(data)
  # `dependencies` is still empty here in the visible code, so this wrapper
  # adds no ordering constraint at this point.
  with tf.control_dependencies(dependencies):
    embedded = tf.identity(embedded)
  # Snapshot of every local name defined so far (see the docstring Note).
  graph = tools.AttrDict(locals())
Source: google-research/planet, planet/scripts/configs.py (view on GitHub)
def _initial_collection(config, params):
  """Describe the initial episode collections for every task.

  Two entries are added for each task in `config.tasks`, in this order:
  `'train-' + task.name`, which saves into `config.train_dir`, and
  `'test-' + task.name`, which saves into `config.test_dir`. Both collect
  `params.get('num_seed_episodes', 5)` episodes.

  Returns:
    A `tools.AttrDict` created with `_unlocked=True`, mapping those keys to
    `tools.AttrDict`s with fields `task`, `save_episode_dir` and
    `num_episodes`.
  """
  episode_count = params.get('num_seed_episodes', 5)
  simulations = tools.AttrDict(_unlocked=True)
  for task in config.tasks:
    # Train entry first, then test, matching the insertion order above.
    for prefix, episode_dir in (
        ('train-', config.train_dir), ('test-', config.test_dir)):
      simulations[prefix + task.name] = tools.AttrDict(
          task=task,
          save_episode_dir=episode_dir,
          num_episodes=episode_count)
  return simulations