How to use the planet.tools.bind function in planet

To help you get started, we’ve selected a few examples of planet.tools.bind, based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github google-research / planet / planet / scripts / configs.py View on Github external
def _data_processing(config, params):
  """Fill in the data-handling entries of `config` from `params`.

  Args:
    config: Object that receives the attributes assigned below.
    params: Mapping-like object queried through `.get(key, default)`.

  Returns:
    The same `config` object, after all attributes have been assigned.

  Raises:
    KeyError: If the 'loader' entry of `params` is not one of 'cache',
      'recent', 'reload' or 'dummy'.
  """
  config.batch_shape = params.get('batch_shape', (50, 50))
  config.num_chunks = params.get('num_chunks', 1)

  # Pre- and post-processing are bound with the same `bits` value.
  bits = params.get('image_bits', 5)
  config.preprocess_fn = tools.bind(tools.preprocess.preprocess, bits=bits)
  config.postprocess_fn = tools.bind(tools.preprocess.postprocess, bits=bits)

  config.open_loop_context = 5
  episodes = tools.numpy_episodes
  config.data_reader = episodes.episode_reader

  # All candidate loaders are built up front; the requested one is then
  # looked up by name, so an unknown name surfaces as a KeyError.
  loaders = dict(
      cache=tools.bind(
          episodes.cache_loader,
          every=params.get('loader_every', 1000)),
      recent=tools.bind(
          episodes.recent_loader,
          every=params.get('loader_every', 1000)),
      reload=episodes.reload_loader,
      dummy=episodes.dummy_loader,
  )
  loader_name = params.get('loader', 'recent')
  config.data_loader = loaders[loader_name]

  action_strategy = params.get('bound_action', 'clip')
  config.bound_action = tools.bind(
      tools.bound_action, strategy=action_strategy)
  return config
github google-research / planet / planet / scripts / configs.py View on Github external
def _data_processing(config, params):
  config.batch_shape = params.get('batch_shape', (50, 50))
  config.num_chunks = params.get('num_chunks', 1)
  image_bits = params.get('image_bits', 5)
  config.preprocess_fn = tools.bind(
      tools.preprocess.preprocess, bits=image_bits)
  config.postprocess_fn = tools.bind(
      tools.preprocess.postprocess, bits=image_bits)
  config.open_loop_context = 5
  config.data_reader = tools.numpy_episodes.episode_reader
  config.data_loader = {
      'cache': tools.bind(
          tools.numpy_episodes.cache_loader,
          every=params.get('loader_every', 1000)),
      'recent': tools.bind(
          tools.numpy_episodes.recent_loader,
          every=params.get('loader_every', 1000)),
      'reload': tools.numpy_episodes.reload_loader,
      'dummy': tools.numpy_episodes.dummy_loader,
  }[params.get('loader', 'recent')]
  config.bound_action = tools.bind(
github google-research / planet / planet / scripts / configs.py View on Github external
def _data_processing(config, params):
  """Fill in the data-handling entries of `config` from `params`.

  Args:
    config: Object that receives the attributes assigned below.
    params: Mapping-like object queried through `.get(key, default)`.

  Returns:
    The same `config` object, after all attributes have been assigned.

  Raises:
    KeyError: If the 'loader' entry of `params` is not one of 'cache',
      'recent', 'reload' or 'dummy'.
  """
  config.batch_shape = params.get('batch_shape', (50, 50))
  config.num_chunks = params.get('num_chunks', 1)

  # Pre- and post-processing are bound with the same `bits` value.
  bits = params.get('image_bits', 5)
  config.preprocess_fn = tools.bind(tools.preprocess.preprocess, bits=bits)
  config.postprocess_fn = tools.bind(tools.preprocess.postprocess, bits=bits)

  config.open_loop_context = 5
  episodes = tools.numpy_episodes
  config.data_reader = episodes.episode_reader

  # All candidate loaders are built up front; the requested one is then
  # looked up by name, so an unknown name surfaces as a KeyError.
  loaders = dict(
      cache=tools.bind(
          episodes.cache_loader,
          every=params.get('loader_every', 1000)),
      recent=tools.bind(
          episodes.recent_loader,
          every=params.get('loader_every', 1000)),
      reload=episodes.reload_loader,
      dummy=episodes.dummy_loader,
  )
  loader_name = params.get('loader', 'recent')
  config.data_loader = loaders[loader_name]

  action_strategy = params.get('bound_action', 'clip')
  config.bound_action = tools.bind(
      tools.bound_action, strategy=action_strategy)
  return config
github google-research / planet / planet / scripts / configs.py View on Github external
def _define_simulation(
    task, config, params, horizon, batch_size, objective='reward',
    rewards=False):
  """Assemble the settings of one simulation as a `tools.AttrDict`.

  Args:
    task: Stored unchanged under the `task` key of the result.
    config: Not read by this function.
    params: Mapping-like object queried through `.get(key, default)`; it
      is also bound as the `params` keyword of the objective.
    horizon: Bound as the `horizon` keyword of the planner.
    batch_size: Stored under the `num_agents` key of the result.
    objective: Name of the `objectives_lib` attribute to bind.
    rewards: Not read by this function.

  Returns:
    A `tools.AttrDict` with the keys `task`, `num_agents`, `planner` and
    `objective`.

  Raises:
    NotImplementedError: If the 'planner' entry of `params` is anything
      other than 'cem'.
  """
  planner_name = params.get('planner', 'cem')
  # 'cem' is the only planner name this function handles.
  if planner_name != 'cem':
    raise NotImplementedError(planner_name)

  cem_planner = tools.bind(
      control.planning.cross_entropy_method,
      amount=params.get('planner_amount', 1000),
      iterations=params.get('planner_iterations', 10),
      topk=params.get('planner_topk', 100),
      horizon=horizon)
  objective_fn = tools.bind(
      getattr(objectives_lib, objective), params=params)

  return tools.AttrDict(
      task=task,
      num_agents=batch_size,
      planner=cem_planner,
      objective=objective_fn)
github google-research / planet / planet / scripts / configs.py View on Github external
rewards=False):
  planner = params.get('planner', 'cem')
  if planner == 'cem':
    planner_fn = tools.bind(
        control.planning.cross_entropy_method,
        amount=params.get('planner_amount', 1000),
        iterations=params.get('planner_iterations', 10),
        topk=params.get('planner_topk', 100),
        horizon=horizon)
  else:
    raise NotImplementedError(planner)
  return tools.AttrDict(
      task=task,
      num_agents=batch_size,
      planner=planner_fn,
      objective=tools.bind(getattr(objectives_lib, objective), params=params))
github google-research / planet / planet / scripts / configs.py View on Github external
activation=config.activation)
  config.encoder = network.encoder
  config.decoder = network.decoder
  config.heads = tools.AttrDict(_unlocked=True)
  config.heads.image = config.decoder
  size = params.get('model_size', 200)
  state_size = params.get('state_size', 30)
  model = params.get('model', 'rssm')
  if model == 'ssm':
    config.cell = tools.bind(
        models.SSM, state_size, size,
        params.get('mean_only', False),
        config.activation,
        params.get('min_stddev', 1e-1))
  elif model == 'rssm':
    config.cell = tools.bind(
        models.RSSM, state_size, size, size,
        params.get('future_rnn', True),
        params.get('mean_only', False),
        params.get('min_stddev', 1e-1),
        config.activation,
        params.get('model_layers', 1))
  elif params.model == 'drnn':
    config.cell = tools.bind(
        models.DRNN, state_size, size, size,
        params.get('mean_only', False),
        params.get('min_stddev', 1e-1), config.activation,
        params.get('drnn_encoder_to_decoder', False),
        params.get('drnn_sample_to_sample', True),
        params.get('drnn_sample_to_encoder', True),
        params.get('drnn_decoder_to_encoder', False),
        params.get('drnn_decoder_to_sample', True),