Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
def make_env(config, seed, mode):
    """Build a single-environment vectorized env for training or evaluation.

    Args:
        config: experiment configuration; ``config['env.id']`` is the id passed
            to ``gym.make``.
        seed: seed forwarded to ``make_vec_env``.
        mode: ``'train'`` or ``'eval'``. Both modes wrap the vectorized env in
            ``VecMonitor``; ``'train'`` additionally wraps it in ``VecStepInfo``.

    Returns:
        The wrapped vectorized environment.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if mode not in ('train', 'eval'):
        raise ValueError("mode must be 'train' or 'eval', got {!r}".format(mode))

    def _make_env():
        # Factory handed to make_vec_env: a fresh gym env wrapped in ClipAction.
        env = gym.make(config['env.id'])
        env = ClipAction(env)
        return env

    env = make_vec_env(_make_env, 1, seed)  # single environment
    env = VecMonitor(env)
    if mode == 'train':
        env = VecStepInfo(env)
    return env
def run(config, seed, device, logdir):
    """Train an agent with a replay buffer and pickle the resulting logs.

    Args:
        config: experiment configuration; ``'replay.capacity'`` is read here,
            the rest is forwarded to ``make_env``, ``Agent`` and ``Engine``.
        seed: seed passed to ``set_global_seeds`` and ``make_env``.
        device: forwarded to ``Agent`` and ``ReplayBuffer``.
        logdir: directory path (joined with ``/``) under which
            ``train_logs`` and ``eval_logs`` are dumped via ``pickle_dump``.

    Returns:
        None
    """
    set_global_seeds(seed)
    # make_env requires an explicit mode. 'train' already applies VecMonitor
    # and VecStepInfo; 'eval' applies VecMonitor only. Re-wrapping here would
    # stack a second VecMonitor/VecStepInfo layer on top.
    env = make_env(config, seed, 'train')
    eval_env = make_env(config, seed, 'eval')
    agent = Agent(config, env, device)
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
    train_logs, eval_logs = engine.train()
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')
    return None
def run(config, seed, device, logdir):
    """Train an agent with a replay buffer and pickle the resulting logs.

    Args:
        config: experiment configuration; ``'replay.capacity'`` is read here,
            the rest is forwarded to ``make_env``, ``Agent`` and ``Engine``.
        seed: seed passed to ``set_global_seeds`` and ``make_env``.
        device: forwarded to ``Agent`` and ``ReplayBuffer``.
        logdir: directory path (joined with ``/``) under which
            ``train_logs`` and ``eval_logs`` are dumped via ``pickle_dump``.

    Returns:
        None
    """
    set_global_seeds(seed)
    # make_env requires an explicit mode. 'train' already applies VecMonitor
    # and VecStepInfo; 'eval' applies VecMonitor only. Re-wrapping here would
    # stack a second VecMonitor/VecStepInfo layer on top.
    env = make_env(config, seed, 'train')
    eval_env = make_env(config, seed, 'eval')
    agent = Agent(config, env, device)
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
    train_logs, eval_logs = engine.train()
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')
    return None
def run(config, seed, device, logdir):
    """Train an agent with optional observation/reward standardization.

    Training iterations continue until ``agent.total_timestep`` reaches
    ``config['train.timestep']``; each iteration's ``train_logger.logs`` is
    collected into a local list.

    Args:
        config: experiment configuration; reads ``'env.standardize_obs'``,
            ``'env.standardize_reward'``, ``'agent.gamma'`` and
            ``'train.timestep'`` here.
        seed: seed passed to ``set_global_seeds`` and ``make_env``.
        device: forwarded to ``Agent``.
        logdir: accepted for interface compatibility.
            NOTE(review): unused in this body, and the collected ``train_logs``
            are neither saved nor returned; confirm this is intended.
    """
    set_global_seeds(seed)
    # make_env requires an explicit mode. 'eval' is the mode that yields the
    # VecMonitor-wrapped env WITHOUT VecStepInfo, which lets the standardizers
    # below be stacked first and VecStepInfo applied last, as before.
    # NOTE(review): the mode name is misleading for a training env; consider a
    # dedicated option in make_env.
    env = make_env(config, seed, 'eval')
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    env = VecStepInfo(env)
    agent = Agent(config, env, device)
    runner = EpisodeRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)
    train_logs = []
    for i in count():
        if agent.total_timestep >= config['train.timestep']:
            break
        train_logger = engine.train(i)
        train_logs.append(train_logger.logs)
def run(config, seed, device):
    """Train an agent with optional standardization, printing periodic logs.

    Training iterations continue until ``agent.total_timestep`` reaches
    ``config['train.timestep']``. The logger is dumped on the first iteration
    and then every ``config['log.freq']`` iterations.

    Args:
        config: experiment configuration; reads ``'log.dir'``, ``'ID'``,
            ``'env.standardize_obs'``, ``'env.standardize_reward'``,
            ``'agent.gamma'``, ``'train.timestep'`` and ``'log.freq'`` here.
        seed: seed passed to ``set_global_seeds`` and ``make_env``; also the
            last component of ``logdir``.
        device: forwarded to ``Agent``.
    """
    set_global_seeds(seed)
    # NOTE(review): logdir is computed but not used below; the collected
    # train_logs are not persisted. Confirm whether a dump step is missing.
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
    # make_env requires an explicit mode. 'eval' yields the VecMonitor-wrapped
    # env without VecStepInfo, i.e. the same wrapper stack this function built
    # before. Re-applying VecMonitor here would stack it twice.
    env = make_env(config, seed, 'eval')
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=5.)
    if config['env.standardize_reward']:
        env = VecStandardizeReward(env, clip=10., gamma=config['agent.gamma'])
    agent = Agent(config, env, device)
    runner = EpisodeRunner(reset_on_call=False)
    engine = Engine(config, agent=agent, env=env, runner=runner)
    train_logs = []
    for i in count():
        if agent.total_timestep >= config['train.timestep']:
            break
        train_logger = engine.train(i)
        train_logs.append(train_logger.logs)
        # Print the first iteration, then every 'log.freq' iterations.
        if i == 0 or (i+1) % config['log.freq'] == 0:
            train_logger.dump(keys=None, index=0, indent=0, border='-'*50)
def run(config, seed, device, logdir):
    """Train an agent with a replay buffer and pickle the resulting logs.

    Args:
        config: experiment configuration; ``'replay.capacity'`` is read here,
            the rest is forwarded to ``make_env``, ``Agent`` and ``Engine``.
        seed: seed passed to ``set_global_seeds`` and ``make_env``.
        device: forwarded to ``Agent`` and ``ReplayBuffer``.
        logdir: directory path (joined with ``/``) under which
            ``train_logs`` and ``eval_logs`` are dumped via ``pickle_dump``.

    Returns:
        None
    """
    set_global_seeds(seed)
    # make_env requires an explicit mode. 'train' already applies VecMonitor
    # and VecStepInfo; 'eval' applies VecMonitor only. Re-wrapping here would
    # stack a second VecMonitor/VecStepInfo layer on top.
    env = make_env(config, seed, 'train')
    eval_env = make_env(config, seed, 'eval')
    agent = Agent(config, env, device)
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
    train_logs, eval_logs = engine.train()
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')
    return None
def run(config, seed, device):
    """Train an agent with a replay buffer and pickle the resulting logs.

    The log directory is derived from the config as
    ``<log.dir>/<ID>/<seed>``.

    Args:
        config: experiment configuration; reads ``'log.dir'``, ``'ID'`` and
            ``'replay.capacity'`` here, the rest is forwarded to
            ``make_env``, ``Agent`` and ``Engine``.
        seed: seed passed to ``set_global_seeds`` and ``make_env``.
        device: forwarded to ``Agent`` and ``ReplayBuffer``.

    Returns:
        None
    """
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
    # make_env requires an explicit mode. 'train'/'eval' already apply
    # VecMonitor, so it is not re-applied here.
    env = make_env(config, seed, 'train')
    eval_env = make_env(config, seed, 'eval')
    agent = Agent(config, env, device)
    # Same ReplayBuffer(env, capacity, device) call as the other run variants
    # in this file; the env argument was missing here.
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
    train_logs, eval_logs = engine.train()
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')
    return None
def initializer(config, seed, device):
    """Build the module-level ``env`` and ``agent`` from ``config``.

    Rebinds the module globals ``env`` and ``agent`` instead of returning
    them.

    Args:
        config: experiment configuration; ``'env.standardize_obs'`` is read
            here, the rest is forwarded to ``make_env`` and ``Agent``.
        seed: seed forwarded to ``make_env``.
        device: forwarded to ``Agent``.
    """
    global env, agent
    # make_env requires an explicit mode. 'eval' yields the VecMonitor-wrapped
    # env without VecStepInfo, matching the stack built here before; the
    # VecMonitor is therefore not re-applied.
    # NOTE(review): consider a dedicated make_env option for this use.
    env = make_env(config, seed, 'eval')
    if config['env.standardize_obs']:
        env = VecStandardizeObservation(env, clip=10.)
    agent = Agent(config, env, device)
def run(config, seed, device):
    """Train an agent with a replay buffer and pickle the resulting logs.

    The log directory is derived from the config as
    ``<log.dir>/<ID>/<seed>``.

    Args:
        config: experiment configuration; reads ``'log.dir'``, ``'ID'`` and
            ``'replay.capacity'`` here, the rest is forwarded to
            ``make_env``, ``Agent`` and ``Engine``.
        seed: seed passed to ``set_global_seeds`` and ``make_env``.
        device: forwarded to ``Agent`` and ``ReplayBuffer``.

    Returns:
        None
    """
    set_global_seeds(seed)
    logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)
    # make_env requires an explicit mode. 'train'/'eval' already apply
    # VecMonitor, so it is not re-applied here.
    env = make_env(config, seed, 'train')
    eval_env = make_env(config, seed, 'eval')
    agent = Agent(config, env, device)
    replay = ReplayBuffer(env, config['replay.capacity'], device)
    engine = Engine(config, agent=agent, env=env, eval_env=eval_env, replay=replay, logdir=logdir)
    train_logs, eval_logs = engine.train()
    pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
    pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')
    return None
def make_env(config, seed, mode):
    """Build a single-environment vectorized env for training or evaluation.

    Args:
        config: experiment configuration; ``config['env.id']`` is the id passed
            to ``gym.make``.
        seed: seed forwarded to ``make_vec_env``.
        mode: ``'train'`` or ``'eval'``. Both modes wrap the vectorized env in
            ``VecMonitor``; ``'train'`` additionally wraps it in ``VecStepInfo``.

    Returns:
        The wrapped vectorized environment.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'eval'``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if mode not in ('train', 'eval'):
        raise ValueError("mode must be 'train' or 'eval', got {!r}".format(mode))

    def _make_env():
        # Factory handed to make_vec_env: a fresh gym env wrapped in ClipAction.
        env = gym.make(config['env.id'])
        env = ClipAction(env)
        return env

    env = make_vec_env(_make_env, 1, seed)  # single environment
    env = VecMonitor(env)
    if mode == 'train':
        env = VecStepInfo(env)
    return env