Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Script tail for a BiResDDPG example; `parser` and the imports are built
# above this view. Override generic Trainer defaults — presumably values
# tuned for this example; confirm against the library defaults.
parser.set_defaults(batch_size=100)
parser.set_defaults(n_warmup=10000)
args = parser.parse_args()
# Separate env instances for training and evaluation so episode state
# does not leak between the two.
env = gym.make(args.env_name)
test_env = gym.make(args.env_name)
# NOTE(review): assumes a continuous (Box) action space — `high.size` and
# `high[0]` would fail on Discrete envs; verify the envs used here.
policy = BiResDDPG(
state_shape=env.observation_space.shape,
action_dim=env.action_space.high.size,
gpu=args.gpu,
eta=args.eta,
memory_capacity=args.memory_capacity,
# Uses only high[0] — implicitly assumes a symmetric, uniform action bound.
max_action=env.action_space.high[0],
batch_size=args.batch_size,
n_warmup=args.n_warmup)
# Trainer drives the train/evaluate loop; invoking the instance starts it.
trainer = Trainer(policy, env, args, test_env=test_env)
trainer()
# Keyword arguments of a policy constructor whose opening line
# (presumably `policy = DQN(` or similar) is outside this view.
enable_categorical_dqn=args.enable_categorical_dqn,
state_shape=env.observation_space.shape,
# Assumes a Discrete action space (`.n`) — TODO confirm env type.
action_dim=env.action_space.n,
# Hard-coded Atari-style DQN hyperparameters (warmup, target sync, batch).
n_warmup=50000,
target_replace_interval=10000,
batch_size=32,
memory_capacity=args.replay_buffer_size,
discount=0.99,
# Epsilon-greedy exploration annealed from 1.0 to 0.1 over 1e6 steps.
epsilon=1.,
epsilon_min=0.1,
epsilon_decay_step=int(1e6),
optimizer=optimizer,
# Gradient update every 4 environment steps.
update_interval=4,
q_func=QFunc,
gpu=args.gpu)
# Trainer drives the train/evaluate loop; invoking the instance starts it.
trainer = Trainer(policy, env, args, test_env=test_env)
trainer()
# Script tail for a SAC example; `parser` and the imports are built above
# this view. Override generic Trainer defaults for this example.
parser.set_defaults(batch_size=100)
parser.set_defaults(n_warmup=10000)
args = parser.parse_args()
# Separate env instances for training and evaluation so episode state
# does not leak between the two.
env = gym.make(args.env_name)
test_env = gym.make(args.env_name)
# NOTE(review): assumes a continuous (Box) action space — `high.size` and
# `high[0]` would fail on Discrete envs; verify the envs used here.
policy = SAC(
state_shape=env.observation_space.shape,
action_dim=env.action_space.high.size,
gpu=args.gpu,
memory_capacity=args.memory_capacity,
max_action=env.action_space.high[0],
batch_size=args.batch_size,
n_warmup=args.n_warmup,
# Automatic entropy-temperature tuning toggled from the CLI.
auto_alpha=args.auto_alpha)
# Trainer drives the train/evaluate loop; invoking the instance starts it.
trainer = Trainer(policy, env, args, test_env=test_env)
trainer()
import time
import numpy as np
import tensorflow as tf
from tf2rl.misc.get_replay_buffer import get_replay_buffer
from tf2rl.experiments.trainer import Trainer
class IRLTrainer(Trainer):
def __init__(
        self,
        policy,
        env,
        args,
        irl,
        expert_obs,
        expert_next_obs,
        expert_act,
        test_env=None):
    """Trainer for inverse-RL: wraps the base ``Trainer`` with an IRL model
    and a set of expert demonstrations.

    Args:
        policy: RL policy to be trained (forwarded to ``Trainer``).
        env: Training environment (forwarded to ``Trainer``).
        args: Parsed CLI arguments; ``dir_suffix`` is mutated here.
        irl: IRL model; must expose ``policy_name``.
        expert_obs: Expert observations.
        expert_next_obs: Expert next-step observations.
        expert_act: Expert actions.
        test_env: Optional separate evaluation environment.
    """
    self._irl = irl
    # Prefix the output-dir suffix with the IRL algorithm name so runs with
    # different discriminators land in distinct log directories.
    args.dir_suffix = self._irl.policy_name + args.dir_suffix
    super().__init__(policy, env, args, test_env)
    # TODO: Add assertion that expert demo dimensions match the current
    # policy/env.
    self._expert_obs = expert_obs
    self._expert_next_obs = expert_next_obs
    # Bug fix: ``expert_act`` was accepted but never stored, so any code
    # reading ``self._expert_act`` would raise AttributeError.
    self._expert_act = expert_act
def get_argument(parser=None):
    """Return an argument parser extended with IRL-specific options.

    Builds on ``Trainer.get_argument`` and adds ``--expert-path-dir``.
    NOTE(review): defined without ``self``/``cls`` — presumably a
    ``@staticmethod`` in the full file; confirm against the class body.
    """
    extended = Trainer.get_argument(parser)
    extended.add_argument(
        '--expert-path-dir',
        default=None,
        help='Path to directory that contains expert trajectories')
    return extended