Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def run(config, seed, device, logdir):
    """Set up a vectorized env, agent and episode-based engine, then iterate.

    The environment from ``make_env`` is wrapped with ``VecMonitor`` first,
    then optionally with observation/reward standardization (controlled by
    ``config['env.standardize_obs']`` / ``config['env.standardize_reward']``),
    and finally with ``VecStepInfo``. Iteration continues until
    ``agent.total_timestep`` reaches ``config['train.timestep']``.

    NOTE(review): this fragment ends at the loop's budget check; the body of
    that check and the rest of the loop (training, logging, checkpointing,
    use of ``logdir``/``train_logs``/``checkpoint_count``) are not visible here.
    """
    set_global_seeds(seed)
    env = make_env(config, seed)
    # VecMonitor is applied before any rescaling wrapper below.
    env = VecMonitor(env)
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        # Reward standardization reuses the agent's discount factor.
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    env = VecStepInfo(env)
    agent = Agent(config, env, device)
    runner = EpisodeRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)
    train_logs = []
    checkpoint_count = 0
    for i in count():
        if agent.total_timestep >= config['train.timestep']:
def run(config, seed, device, logdir):
    """Train an agent with a ``StepRunner`` until the timestep budget is used up.

    Seeds all RNGs, builds the 'train' environment, the agent, a
    ``StepRunner`` and an ``Engine``, then calls ``engine.train(i)`` for
    i = 0, 1, 2, ... until ``agent.total_timestep`` reaches
    ``config['train.timestep']``.

    * Logs are dumped on the first iteration and every ``config['log.freq']``
      iterations after that.
    * ``config['checkpoint.num']`` (n) checkpoints are taken at evenly spaced
      fractions 0, 1/(n-1), ..., 1 of the timestep budget. With n == 1 only
      the end-of-budget checkpoint is taken; with n <= 0 none are taken.
    * The per-iteration logs are saved with ``pickle_dump`` under
      ``logdir/'train_logs'`` (extension '.pkl').

    Args:
        config: configuration mapping; this function reads 'train.timestep',
            'log.freq' and 'checkpoint.num' (the helpers read more).
        seed: seed for ``set_global_seeds`` and ``make_env``.
        device: device handed to ``Agent``.
        logdir: output directory, joined with ``/`` (``pathlib.Path``-like).

    Returns:
        None.
    """
    set_global_seeds(seed)
    env = make_env(config, seed, 'train')
    agent = Agent(config, env, device)
    runner = StepRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)

    budget = config['train.timestep']
    num_checkpoints = config['checkpoint.num']

    def checkpoint_threshold(k):
        # Timestep at which the k-th (0-based) checkpoint becomes due.
        # k/(n-1) is undefined for n == 1, so a lone checkpoint is placed at
        # the end of the budget instead of raising ZeroDivisionError.
        fraction = k/(num_checkpoints - 1) if num_checkpoints > 1 else 1.0
        return int(budget*fraction)

    train_logs = []
    checkpoint_count = 0
    for i in count():
        if agent.total_timestep >= budget:
            break
        train_logger = engine.train(i)
        train_logs.append(train_logger.logs)
        if i == 0 or (i+1) % config['log.freq'] == 0:
            train_logger.dump(keys=None, index=0, indent=0, border='-'*50)
        # The counter must advance after each checkpoint; otherwise the
        # threshold stays at 0 and a checkpoint is written every iteration.
        if (checkpoint_count < num_checkpoints
                and agent.total_timestep >= checkpoint_threshold(checkpoint_count)):
            agent.checkpoint(logdir, i + 1)
            checkpoint_count += 1
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    return None
def run(config, seed, device):
    """Set up a vectorized env, agent and episode-based engine, then iterate.

    The output directory is derived from the config as
    ``<config['log.dir']>/<config['ID']>/<seed>``. The environment is wrapped
    with ``VecMonitor``, optional observation/reward standardization, and
    ``VecStepInfo``.

    NOTE(review): this fragment ends at the ``for i in count():`` header; the
    loop body (and the use of ``logdir``, ``engine``, ``train_logs`` and
    ``checkpoint_count``) is not visible here.
    """
    set_global_seeds(seed)
    # One sub-directory per experiment ID and per seed.
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
    env = make_env(config, seed)
    env = VecMonitor(env)
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        # Reward standardization reuses the agent's discount factor.
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    env = VecStepInfo(env)
    agent = Agent(config, env, device)
    runner = EpisodeRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)
    train_logs = []
    checkpoint_count = 0
    for i in count():
def run(config, seed, device, logdir):
    """Set up and start a CMA-ES search over the agent's parameters.

    Seeds RNGs, builds an agent on an 'eval' environment (its ``num_params``
    sets the length of the search vector), creates a ``CMAES`` optimizer and a
    process pool. Then, for each of ``config['train.generations']``
    generations, it asks the optimizer for candidate solutions and maps
    ``fitness`` over them in the pool.

    NOTE(review): this fragment stops right after ``pool.map``. ``t0``,
    ``out``, ``train_logs``, ``checkpoint_count`` and ``logdir`` are unused in
    the visible code, so the rest of the generation loop is presumably not
    shown here.
    NOTE(review): ``multiprocessing.Pool`` requires at least one process, so
    ``train.popsize // train.worker_chunksize`` must be >= 1 (popsize must
    not be smaller than worker_chunksize).
    """
    set_global_seeds(seed)
    torch.set_num_threads(1) # VERY IMPORTANT TO AVOID GETTING STUCK
    print('Initializing...')
    agent = Agent(config, make_env(config, seed, 'eval'), device)
    # Presumably mu0/std0 are the initial mean value and step size; confirm
    # against the CMAES wrapper.
    es = CMAES([config['train.mu0']]*agent.num_params, config['train.std0'],
               {'popsize': config['train.popsize'],
                'seed': seed})
    train_logs = []
    checkpoint_count = 0
    # One process per chunk of `worker_chunksize` population members.
    with Pool(processes=config['train.popsize']//config['train.worker_chunksize']) as pool:
        print('Finish initialization. Training starts...')
        for generation in range(config['train.generations']):
            t0 = time.perf_counter()
            solutions = es.ask()
            # Each fitness call receives (config, seed, device, solution).
            # CloudpickleWrapper presumably makes `fitness` picklable for the
            # worker processes; confirm.
            data = [(config, seed, device, solution) for solution in solutions]
            out = pool.map(CloudpickleWrapper(fitness), data, chunksize=config['train.worker_chunksize'])
def __call__(self, config, seed, device):
    """Build the training (and optionally an independent evaluation) vec env.

    NOTE(review): this fragment is cut off inside the ``VecStandardize(...)``
    call, so the rest of the method is not visible here. Indentation was
    missing in the source; the clip-action nesting below is inferred.
    """
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
    # Optionally request a TimeAwareObservation wrapper via make_vec_env's
    # `extra_wrapper` keyword.
    if config['env.time_aware_obs']:
        kwargs = {'extra_wrapper': [TimeAwareObservation]}
    else:
        kwargs = {}
    env = make_vec_env(SerialVecEnv, make_gym_env, config['env.id'], config['train.N'], seed, monitor=True, **kwargs)
    # NOTE(review): the evaluation env uses the same seed as the training env
    # and receives neither `kwargs` nor `monitor=True`; confirm this is intended.
    if config['eval.independent']:
        eval_env = make_vec_env(SerialVecEnv, make_gym_env, config['env.id'], config['eval.N'], seed)
    if config['env.clip_action']:
        env = VecClipAction(env)
        if config['eval.independent']:
            eval_env = VecClipAction(eval_env)
    if config['env.standardize']: # running averages of observation and reward
        env = VecStandardize(venv=env,
                             use_obs=True,
def run(config, seed, device, logdir):
    """Run ``Engine.train`` with a replay buffer and pickle the returned logs.

    Seeds all RNGs and builds separate 'train' and 'eval' environments. It then
    wires a ``RandomAgent``, the ``Agent``, an ``EpisodeRunner`` and a
    ``ReplayBuffer`` (capacity ``config['replay.capacity']``) into an
    ``Engine``. ``engine.train()`` must return a (train_logs, eval_logs)
    pair; each element is saved with ``pickle_dump`` under ``logdir``.

    Returns:
        None.
    """
    set_global_seeds(seed)
    train_env = make_env(config, seed, 'train')
    test_env = make_env(config, seed, 'eval')
    explorer = RandomAgent(config, train_env, device)
    learner = Agent(config, train_env, device)
    # Keyword arguments are evaluated left to right, so the runner is still
    # constructed before the replay buffer.
    engine = Engine(
        config,
        agent=learner,
        random_agent=explorer,
        env=train_env,
        eval_env=test_env,
        runner=EpisodeRunner(),
        replay=ReplayBuffer(train_env, config['replay.capacity'], device),
        logdir=logdir,
    )
    train_logs, eval_logs = engine.train()
    for name, logs in (('train_logs', train_logs), ('eval_logs', eval_logs)):
        pickle_dump(obj=logs, f=logdir/name, ext='.pkl')
def run(config, seed, device, logdir):
    """Run ``Engine.train`` on monitored envs with a replay buffer; save logs.

    The training env is ``make_env`` wrapped by ``VecMonitor`` then
    ``VecStepInfo``. The evaluation env is a second ``make_env`` instance
    wrapped only by ``VecMonitor``. The two log collections returned by
    ``engine.train()`` are saved with ``pickle_dump`` under ``logdir``.

    NOTE(review): both envs are built with the same seed and the eval env
    lacks ``VecStepInfo``; confirm this is intended.

    Returns:
        None.
    """
    set_global_seeds(seed)
    env = VecStepInfo(VecMonitor(make_env(config, seed)))
    eval_env = VecMonitor(make_env(config, seed))
    agent = Agent(config, env, device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env,
                    replay=ReplayBuffer(env, config['replay.capacity'], device),
                    logdir=logdir)
    train_logs, eval_logs = engine.train()
    for name, logs in (('train_logs', train_logs), ('eval_logs', eval_logs)):
        pickle_dump(obj=logs, f=logdir/name, ext='.pkl')
def run(config, seed, device):
    """Set up an OpenAI-ES search over the agent's parameters.

    Seeds RNGs, derives ``logdir`` as ``<log.dir>/<ID>/<seed>``, and builds an
    agent whose ``num_params`` sets the length of the search vector. It then
    creates an ``OpenAIES`` optimizer from the ``train.*`` config entries and
    opens a process pool with one worker per population member.

    NOTE(review): this fragment ends right after the executor is opened.
    ``logdir``, ``es`` and ``train_logs`` are unused in the visible code, so the
    generation loop is presumably not shown here.
    """
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
    print('Initializing...')
    agent = Agent(config, make_env(config, seed), device)
    # Presumably mu0/std0 are the initial mean value and noise scale; confirm
    # against the OpenAIES wrapper.
    es = OpenAIES([config['train.mu0']]*agent.num_params, config['train.std0'],
                  {'popsize': config['train.popsize'],
                   'seed': seed,
                   'sigma_scheduler_args': config['train.sigma_scheduler_args'],
                   'lr': config['train.lr'],
                   'lr_decay': config['train.lr_decay'],
                   'min_lr': config['train.min_lr'],
                   'antithetic': config['train.antithetic'],
                   'rank_transform': config['train.rank_transform']})
    train_logs = []
    # Each worker process calls initializer(config, seed, device) once at
    # start-up.
    with ProcessPoolExecutor(max_workers=config['train.popsize'], initializer=initializer, initargs=(config, seed, device)) as executor:
        print('Finish initialization. Training starts...')
def run(config, seed, device):
    """Set up a monitored, optionally standardized env, an agent and an engine.

    The output directory is derived as ``<log.dir>/<ID>/<seed>``. Unlike
    sibling variants, this one does not wrap the env with ``VecStepInfo``.

    NOTE(review): as visible, the ``for`` loop only checks the timestep budget
    and breaks. Nothing shown inside it advances ``agent.total_timestep``, so
    if the budget is not already met the loop never terminates. The training
    body (and the use of ``logdir``, ``engine`` and ``train_logs``) is
    presumably missing from this fragment; confirm before use.
    """
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
    env = make_env(config, seed)
    env = VecMonitor(env)
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        # Reward standardization reuses the agent's discount factor.
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    agent = Agent(config, env, device)
    runner = EpisodeRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)
    train_logs = []
    for i in count():
        if agent.total_timestep >= config['train.timestep']:
            break
def __call__(self, config, seed, device):
    """Seed RNGs, then construct and invoke an ``ESMaster`` driving ``ESWorker``.

    The master receives ``<config['log.dir']>/<config['ID']>/<seed>`` as its
    ``logdir``. ``device`` is not used by this method.

    Returns:
        None.
    """
    set_global_seeds(seed)
    run_dir = Path(config['log.dir']).joinpath(str(config['ID']), str(seed))
    master = ESMaster(config, ESWorker, logdir=run_dir)
    master()