def test_run_discrete(self):
    # Excerpt from a unittest-style test case; self.parser is provided by the
    # surrounding test class.
    import gym
    from tf2rl.algos.dqn import DQN

    parser = DQN.get_argument(self.parser)
    parser.set_defaults(n_warmup=1)
    args, _ = parser.parse_known_args()

    def env_fn():
        return gym.make("CartPole-v0")

    def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
        return DQN(
            name=name,
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            n_warmup=500,
            target_replace_interval=300,
            batch_size=32,
            memory_capacity=memory_capacity,
            discount=0.99)
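
For a quick standalone check of the factory pattern above, the same constructor arguments can be used to build a CartPole policy directly. Note that the get_action call below is an assumption about tf2rl's DQN interface (it does not appear in the snippet) and should be verified against the library.

import gym
from tf2rl.algos.dqn import DQN

# Build the same CartPole policy that policy_fn above would produce.
env = gym.make("CartPole-v0")
policy = DQN(
    state_shape=env.observation_space.shape,
    action_dim=env.action_space.n,
    n_warmup=500,
    target_replace_interval=300,
    batch_size=32,
    memory_capacity=int(1e6),
    discount=0.99,
    gpu=-1)

# Query a single action from the untrained policy (get_action is assumed here).
obs = env.reset()
action = policy.get_action(obs)
assert 0 <= action < env.action_space.n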
import argparse

import numpy as np
import gym
import tensorflow as tf

from tf2rl.algos.apex import apex_argument, run
from tf2rl.algos.dqn import DQN
from tf2rl.misc.target_update_ops import update_target_variables
from tf2rl.networks.atari_model import AtariQFunc


if __name__ == '__main__':
    parser = apex_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument('--atari', action='store_true')
    parser.add_argument('--env-name', type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    args = parser.parse_args()

    if args.atari:
        env_name = args.env_name
        n_warmup = 50000
        target_replace_interval = 10000
        batch_size = 32
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=0.0000625, epsilon=1.5e-4)
        epsilon_decay_rate = int(1e6)
        QFunc = AtariQFunc
    else:
        env_name = "CartPole-v0"
import gym
import tensorflow as tf

from tf2rl.algos.dqn import DQN
from tf2rl.envs.atari_wrapper import wrap_dqn
from tf2rl.experiments.trainer import Trainer
from tf2rl.networks.atari_model import AtariQFunc as QFunc


if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.add_argument("--replay-buffer-size", type=int, default=int(1e6))
    parser.add_argument('--env-name', type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    parser.set_defaults(episode_max_steps=108000)
    parser.set_defaults(test_interval=10000)
    parser.set_defaults(max_steps=int(1e9))
    parser.set_defaults(save_model_interval=500000)
    parser.set_defaults(gpu=0)
    parser.set_defaults(show_test_images=True)
    args = parser.parse_args()

    env = wrap_dqn(gym.make(args.env_name))
    test_env = wrap_dqn(gym.make(args.env_name), reward_clipping=False)

    # The following parameters match the DeepMind DQN paper:
    # https://www.nature.com/articles/nature14236
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=0.0000625, epsilon=1.5e-4)  # same Adam settings as the Ape-X example above
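
The excerpt stops before the policy is built. Below is a hedged continuation sketch using the names defined above: the state_shape/action_dim and replay-related keywords mirror the earlier snippets, while passing q_func= and optimizer= to DQN and the Trainer(policy, env, args, test_env=test_env) call are assumptions about tf2rl's interfaces rather than code from the original script.

    # Hypothetical continuation, not part of the original excerpt.
    policy = DQN(
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        q_func=QFunc,             # assumed keyword for the Atari Q-network
        optimizer=optimizer,      # assumed keyword for the custom Adam settings
        n_warmup=50000,
        target_replace_interval=10000,
        batch_size=32,
        memory_capacity=args.replay_buffer_size,
        gpu=args.gpu)

    # Usual tf2rl pattern: the Trainer object drives the interaction/update loop.
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()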
import gym

from tf2rl.algos.dqn import DQN
from tf2rl.experiments.trainer import Trainer


if __name__ == '__main__':
    parser = Trainer.get_argument()
    parser = DQN.get_argument(parser)
    parser.set_defaults(test_interval=2000)
    parser.set_defaults(max_steps=100000)
    parser.set_defaults(gpu=-1)
    parser.set_defaults(n_warmup=500)
    parser.set_defaults(batch_size=32)
    parser.set_defaults(memory_capacity=int(1e4))
    parser.add_argument('--env-name', type=str, default="CartPole-v0")
    args = parser.parse_args()

    env = gym.make(args.env_name)
    test_env = gym.make(args.env_name)

    policy = DQN(
        enable_double_dqn=args.enable_double_dqn,
        enable_dueling_dqn=args.enable_dueling_dqn,
        enable_noisy_dqn=args.enable_noisy_dqn,
        enable_categorical_dqn=args.enable_categorical_dqn,
        # Remaining arguments mirror the CartPole settings used in the snippets above.
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        target_replace_interval=300,
        discount=0.99,
        memory_capacity=args.memory_capacity,
        batch_size=args.batch_size,
        n_warmup=args.n_warmup,
        gpu=args.gpu)

    # Hand the policy and environments to the Trainer, which runs the training loop.
    trainer = Trainer(policy, env, args, test_env=test_env)
    trainer()
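
After training finishes, the resulting policy can be rolled out once for a quick qualitative check. The sketch below continues from the names defined above and assumes tf2rl's get_action(obs, test=True) interface and the classic gym step API used by CartPole-v0; both are assumptions to verify, not part of the original example.

    # Hypothetical greedy evaluation rollout, not part of the original script.
    obs = test_env.reset()
    done = False
    episode_return = 0.0
    while not done:
        action = policy.get_action(obs, test=True)  # test=True assumed to disable exploration
        obs, reward, done, _ = test_env.step(action)
        episode_return += reward
    print("evaluation return:", episode_return)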