# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# --- ortho_init on recurrent modules ---
# Multi-layer LSTM: every weight matrix should receive the large orthogonal
# scale (so its max entry is well above 50) and every bias the constant 10.
a = nn.LSTM(2, 3, 2)
ortho_init(a, weight_scale=1000., constant_bias=10.)
for layer in (0, 1):
    for gate in ('hh', 'ih'):
        weight = getattr(a, f'weight_{gate}_l{layer}')
        assert weight.max().item() > 50.
for layer in (0, 1):
    for gate in ('hh', 'ih'):
        bias = getattr(a, f'bias_{gate}_l{layer}')
        assert np.allclose(bias.detach().numpy(), 10.)
# LSTMCell: a single cell, same expectations on its two weights and biases.
a = nn.LSTMCell(3, 2)
ortho_init(a, weight_scale=1000., constant_bias=10.)
for gate in ('hh', 'ih'):
    assert getattr(a, f'weight_{gate}').max().item() > 50.
for gate in ('hh', 'ih'):
    assert np.allclose(getattr(a, f'bias_{gate}').detach().numpy(), 10.)
# NOTE(review): fragment of an actor-critic agent __init__ — the enclosing
# `def` header lies outside this chunk; every statement below mutates `self`.
# Shared MLP feature extractor; the last entry of 'nn.sizes' is the feature dim.
self.feature_network = MLP(config, env, device, **kwargs)
feature_dim = config['nn.sizes'][-1]
# Policy head chosen from the action-space type: categorical for Discrete,
# diagonal Gaussian for Box (continuous). No else-branch: other space types
# leave self.action_head undefined.
if isinstance(env.action_space, Discrete):
self.action_head = CategoricalHead(feature_dim, env.action_space.n, device, **kwargs)
elif isinstance(env.action_space, Box):
self.action_head = DiagGaussianHead(feature_dim,
flatdim(env.action_space),
device,
config['agent.std0'],
config['agent.std_style'],
config['agent.std_range'],
config['agent.beta'],
**kwargs)
# Scalar state-value head, orthogonally initialized with unit scale.
self.V_head = nn.Linear(feature_dim, 1).to(device)
ortho_init(self.V_head, weight_scale=1.0, constant_bias=0.0)
self.total_timestep = 0
# Adam over all parameters; optional linear LR decay over the training budget.
self.optimizer = optim.Adam(self.parameters(), lr=config['agent.lr'])
if config['agent.use_lr_scheduler']:
self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep'], min_lr=1e-8)
def __init__(self, config, env, device, **kwargs):
    """Build the actor-critic modules: shared MLP features, an action head
    matched to the action-space type, and a scalar state-value head.

    Args:
        config: dict-like configuration; reads 'nn.sizes', 'agent.std0',
            'agent.lr', 'agent.use_lr_scheduler', 'train.timestep'.
        env: environment providing `action_space`.
        device: torch device the value head is moved to.
        **kwargs: forwarded to the parent class and the sub-networks.

    Raises:
        TypeError: if the action space is neither Discrete nor Box.
    """
    super().__init__(config, env, device, **kwargs)
    feature_dim = config['nn.sizes'][-1]
    self.feature_network = MLP(config, env, device, **kwargs)
    # Action head: categorical for discrete spaces, diagonal Gaussian for
    # continuous (Box) spaces.
    if isinstance(env.action_space, spaces.Discrete):
        self.action_head = CategoricalHead(feature_dim, env.action_space.n, device, **kwargs)
    elif isinstance(env.action_space, spaces.Box):
        self.action_head = DiagGaussianHead(feature_dim, spaces.flatdim(env.action_space), device, config['agent.std0'], **kwargs)
    else:
        # Fail fast instead of leaving `self.action_head` undefined and
        # crashing later with an opaque AttributeError.
        raise TypeError(f'unsupported action space: {type(env.action_space)}')
    self.V_head = nn.Linear(feature_dim, 1)
    ortho_init(self.V_head, weight_scale=1.0, constant_bias=0.0)
    self.V_head = self.V_head.to(device)  # reproducible between CPU/GPU, ortho_init behaves differently
    self.total_timestep = 0
    self.optimizer = optim.Adam(self.parameters(), lr=config['agent.lr'])
    if config['agent.use_lr_scheduler']:
        self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep'], min_lr=1e-8)
def init_params(self, config):
    """Orthogonally (re-)initialize the network parameters.

    Feature layers use the ReLU gain; the two output heads use a small
    weight scale (0.01) so initial outputs stay near zero. All biases are
    set to zero.
    """
    for fc in self.feature_layers:
        ortho_init(fc, nonlinearity='relu', constant_bias=0.0)
    # Same order as before: mean head first, then log-variance head.
    for head in (self.mean_head, self.logvar_head):
        ortho_init(head, weight_scale=0.01, constant_bias=0.0)
# NOTE(review): tail of a VAE __init__ — the header and the creation of
# self.mean_head are outside this fragment's view.
ortho_init(self.mean_head, weight_scale=0.01, constant_bias=0.0)
# Log-variance head of the latent Gaussian; small init keeps it near zero.
self.logvar_head = nn.Linear(256, config['nn.z_dim'])
ortho_init(self.logvar_head, weight_scale=0.01, constant_bias=0.0)
# Project latent code back up to the decoder's input width.
self.decoder_fc = nn.Linear(config['nn.z_dim'], 256)
ortho_init(self.decoder_fc, nonlinearity='relu', constant_bias=0.0)
# Transposed-convolution decoder stack.
self.decoder = make_transposed_cnn(input_channel=64,
channels=[64, 64, 64],
kernels=[4, 4, 4],
strides=[2, 1, 1],
paddings=[0, 0, 0],
output_paddings=[0, 0, 0])
for layer in self.decoder:
ortho_init(layer, nonlinearity='relu', constant_bias=0.0)
# Final pixel head: 9216 flattened decoder features -> 784 outputs
# (presumably 28x28 images — TODO confirm against the dataset).
self.x_head = nn.Linear(9216, 784)
ortho_init(self.x_head, nonlinearity='sigmoid', constant_bias=0.0)
self.to(device)
self.total_iter = 0
# NOTE(review): body of a VAE __init__ — the `def` header is not visible in
# this chunk; everything below mutates `self`.
super().__init__(**kwargs)
self.config = config
self.device = device
# Convolutional encoder over single-channel input.
self.encoder = make_cnn(input_channel=1,
channels=[64, 64, 64],
kernels=[4, 4, 4],
strides=[2, 2, 1],
paddings=[0, 0, 0])
for layer in self.encoder:
ortho_init(layer, nonlinearity='relu', constant_bias=0.0)
# Latent Gaussian parameters: mean and log-variance heads with small init
# so the initial posterior stays close to the prior.
self.mean_head = nn.Linear(256, config['nn.z_dim'])
ortho_init(self.mean_head, weight_scale=0.01, constant_bias=0.0)
self.logvar_head = nn.Linear(256, config['nn.z_dim'])
ortho_init(self.logvar_head, weight_scale=0.01, constant_bias=0.0)
# Project latent code back up to the decoder's input width.
self.decoder_fc = nn.Linear(config['nn.z_dim'], 256)
ortho_init(self.decoder_fc, nonlinearity='relu', constant_bias=0.0)
# Transposed-convolution decoder stack.
self.decoder = make_transposed_cnn(input_channel=64,
channels=[64, 64, 64],
kernels=[4, 4, 4],
strides=[2, 1, 1],
paddings=[0, 0, 0],
output_paddings=[0, 0, 0])
for layer in self.decoder:
ortho_init(layer, nonlinearity='relu', constant_bias=0.0)
# Final pixel head: 9216 flattened decoder features -> 784 outputs
# (presumably 28x28 images — TODO confirm against the dataset).
self.x_head = nn.Linear(9216, 784)
ortho_init(self.x_head, nonlinearity='sigmoid', constant_bias=0.0)
self.to(device)
self.total_iter = 0
def __init__(self, config, env, device, **kwargs):
    """State-value network: fully-connected feature stack plus a scalar head.

    Args:
        config: configuration; 'nn.sizes' lists the hidden-layer widths.
        env: environment providing `observation_space`.
        device: torch device the whole module is moved to.
        **kwargs: forwarded to the parent constructor.
    """
    super().__init__(**kwargs)
    self.config = config
    self.env = env
    self.device = device
    obs_dim = spaces.flatdim(env.observation_space)
    self.feature_layers = make_fc(obs_dim, config['nn.sizes'])
    for fc in self.feature_layers:
        ortho_init(fc, nonlinearity='relu', constant_bias=0.0)
    # Scalar value head on top of the last hidden layer, unit-scale init.
    self.V_head = nn.Linear(config['nn.sizes'][-1], 1)
    ortho_init(self.V_head, weight_scale=1.0, constant_bias=0.0)
    self.to(self.device)
def __init__(self, config, env, device, **kwargs):
    """MLP feature extractor with one LayerNorm per hidden layer.

    Args:
        config: configuration; 'nn.sizes' lists the hidden-layer widths.
        env: environment providing `observation_space`.
        device: torch device the whole module is moved to.
        **kwargs: forwarded to the parent constructor.
    """
    super().__init__(**kwargs)
    self.config = config
    self.env = env
    self.device = device
    in_dim = flatdim(env.observation_space)
    self.feature_layers = make_fc(in_dim, config['nn.sizes'])
    for fc in self.feature_layers:
        ortho_init(fc, nonlinearity='relu', constant_bias=0.0)
    # One LayerNorm per hidden layer, sized to match its output width.
    self.layer_norms = nn.ModuleList([nn.LayerNorm(size) for size in config['nn.sizes']])
    self.to(self.device)
def __init__(self, in_features, out_features, num_density, **kwargs):
    """Mixture-density output layer.

    Three parallel linear projections produce, per mixture component, the
    mixing logits (pi), means, and log-variances.

    Args:
        in_features: size of the incoming feature vector.
        out_features: dimensionality of each mixture component.
        num_density: number of mixture components.
        **kwargs: forwarded to the parent constructor.
    """
    super().__init__(**kwargs)
    self.in_features = in_features
    self.out_features = out_features
    self.num_density = num_density
    width = out_features * num_density
    # Create/init interleaved (as before) so RNG draw order is unchanged;
    # small orthogonal scale keeps initial mixture outputs near zero.
    self.pi_head = nn.Linear(in_features, width)
    ortho_init(self.pi_head, weight_scale=0.01, constant_bias=0.0)
    self.mean_head = nn.Linear(in_features, width)
    ortho_init(self.mean_head, weight_scale=0.01, constant_bias=0.0)
    self.logvar_head = nn.Linear(in_features, width)
    ortho_init(self.logvar_head, weight_scale=0.01, constant_bias=0.0)
def init_params(self, config):
for layer in self.feature_layers:
ortho_init(layer, nonlinearity='leaky_relu', constant_bias=0.0)