Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Demo rollout: run one episode with the trained `policy` in a freshly made
# env (rendering controlled by args.render), print its reward and length,
# then close the collector.
# NOTE(review): `policy`, `args`, `gym` and `Collector` come from outside
# this view.
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
# here we define an imitation collector with a trivial policy
# Imitation-learning stage: build an Actor network, wrap it in a
# continuous-mode ImitationPolicy, and pass it to offpolicy_trainer together
# with the existing train_collector and a new collector on test_envs.
# NOTE(review): "trivial policy" is unclear -- the code below trains an
# Actor-based ImitationPolicy. Indentation was lost here; presumably only the
# reward_threshold line is the body of the `if` (TODO confirm).
if args.task == 'Pendulum-v0':
# Lower the reward goal for Pendulum. Presumably stop_fn compares against
# env.spec.reward_threshold (stop_fn is defined outside this view -- verify).
env.spec.reward_threshold = -300 # lower the goal
net = Actor(Net(1, args.state_shape, device=args.device),
args.action_shape, args.max_action, args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim, mode='continuous')
il_test_collector = Collector(il_policy, test_envs)
# Reset the shared train_collector before reusing it for this stage.
train_collector.reset()
# args.step_per_epoch // 5 is passed where the main trainer call (not visible
# here) presumably passes args.step_per_epoch -- i.e. a smaller IL budget.
result = offpolicy_trainer(
il_policy, train_collector, il_test_collector, args.epoch,
args.step_per_epoch // 5, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
# The imitation policy's best reward must satisfy stop_fn.
assert stop_fn(result['best_reward'])
train_collector.close()
il_test_collector.close()
# Script-mode extras: print the imitation-learning result, then run one
# episode with il_policy in a fresh env and print its reward and length.
# NOTE(review): indentation was lost; the lines below are presumably the body
# of this `if` (TODO confirm against the original file).
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(il_policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
# NOTE(review): this is the tail of an offpolicy_trainer(...) call whose
# opening lines are not in this view. Afterwards the best reward must satisfy
# stop_fn, and the test collector is closed.
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
test_collector.close()
# Script-mode extras: print the training result, then run one episode with
# the trained `policy` in a fresh env and print its reward and length.
# NOTE(review): indentation was lost; the lines below are presumably the body
# of this `if` (TODO confirm against the original file).
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
collector.close()
# here we define an imitation collector with a trivial policy
# Imitation-learning stage: build an Actor network, wrap it in a
# continuous-mode ImitationPolicy, and pass it to offpolicy_trainer together
# with the existing train_collector and a new collector on test_envs.
# NOTE(review): "trivial policy" is unclear -- the code below trains an
# Actor-based ImitationPolicy. Indentation was lost here; presumably only the
# reward_threshold line is the body of the `if` (TODO confirm).
if args.task == 'Pendulum-v0':
# Lower the reward goal for Pendulum. Presumably stop_fn compares against
# env.spec.reward_threshold (stop_fn is defined outside this view -- verify).
env.spec.reward_threshold = -300 # lower the goal
net = Actor(Net(1, args.state_shape, device=args.device),
args.action_shape, args.max_action, args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim, mode='continuous')
il_test_collector = Collector(il_policy, test_envs)
# Reset the shared train_collector before reusing it for this stage.
train_collector.reset()
# args.step_per_epoch // 5 is passed where the main trainer call (not visible
# here) presumably passes args.step_per_epoch -- i.e. a smaller IL budget.
result = offpolicy_trainer(
il_policy, train_collector, il_test_collector, args.epoch,
args.step_per_epoch // 5, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
# The imitation policy's best reward must satisfy stop_fn.
assert stop_fn(result['best_reward'])
train_collector.close()
il_test_collector.close()
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
# NOTE(review): the heading above does not describe the Batch assertions that
# follow; b5 and b5_dict are defined outside this view.
assert b5.b.d[0] == b5_dict[0]['b']['d']
assert b5.b.d[1] == 0.0
# test stack with incompatible keys
# Keys missing from an element are zero-filled in the stacked result.
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
# test stack with empty Batch()
# An empty Batch() value is zero-filled when other elements hold data for that
# key (d), and a key that is Batch() in every element stays empty (e).
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
# Stacking along axis=-1: keys whose only values are Batch() stay empty, and
# the shared array matches np.stack(..., axis=-1).
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert np.allclose(test.common.c,
np.stack([b1.common.c, b2.common.c], axis=-1))
def test_batch_from_to_numpy_without_copy():
    """Round-tripping a Batch through torch and back keeps the same buffers.

    Records the data pointer of a top-level and a nested numpy array, calls
    ``to_torch()`` then ``to_numpy()``, and checks both pointers are unchanged.
    """
    def _data_ptr(arr):
        # Address of the first element of the array's underlying buffer.
        return arr.__array_interface__['data'][0]

    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
    ptrs_before = (_data_ptr(batch.a), _data_ptr(batch.b.c))
    batch.to_torch()
    batch.to_numpy()
    ptrs_after = (_data_ptr(batch.a), _data_ptr(batch.b.c))
    assert ptrs_after[0] == ptrs_before[0]
    assert ptrs_after[1] == ptrs_before[1]
def test_batch_empty():
    """Check that ``Batch.empty`` / ``Batch.empty_`` reset values by dtype.

    Emptied entries become 0 for numeric fields, False for bool fields and
    None for object fields (see the assertions below).

    Uses the builtin ``object`` / ``int`` as dtypes. The former ``np.object`` /
    ``np.int`` aliases referred to exactly these builtins, but were deprecated
    in NumPy 1.20 and removed in NumPy 1.24, where they raise AttributeError.
    """
    b5_dict = np.array([{'a': False, 'b': {'c': 2.0, 'd': 1.0}},
                        {'a': True, 'b': {'c': 3.0}}])
    b5 = Batch(b5_dict)
    # Overwrite the second entry with an emptied copy of the first entry.
    b5[1] = Batch.empty(b5[0])
    assert np.allclose(b5.a, [False, False])
    assert np.allclose(b5.b.c, [2, 0])
    assert np.allclose(b5.b.d, [1, 0])
    data = Batch(a=[False, True],
                 b={'c': np.array([2., 'st'], dtype=object),
                    'd': [1, None],
                    'e': [2., float('nan')]},
                 c=np.array([1, 3, 4], dtype=int),
                 t=torch.tensor([4, 5, 6, 7.]))
    # Empty the last entry of every field.
    data[-1] = Batch.empty(data[1])
    assert np.allclose(data.c, [1, 3, 0])
    assert np.allclose(data.a, [False, False])
    assert list(data.b.c) == [2.0, None]
    assert list(data.b.d) == [1, None]
    assert np.allclose(data.b.e, [2, 0])
    assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.]))
    # Emptying the indexed copy data[0] does not write back for the fields a,
    # b.c, b.d, b.e and c; only the torch field t changes (next assertion).
    data[0].empty_()
    assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0]))
    # empty_(index=...) modifies the batch in place for every field.
    data.empty_(index=0)
    assert np.allclose(data.c, [0, 3, 0])
    assert list(data.b.c) == [None, None]
    assert list(data.b.d) == [None, None]
    assert list(data.b.e) == [0, 0]
    # Emptying an empty Batch is a no-op that leaves its shape as [].
    b0 = Batch()
    b0.empty_()
    assert b0.shape == []
# NOTE(review): fragment of a Batch test; b5 is defined outside this view.
assert b5.b.d[1] == 0.0
# test stack with incompatible keys
# Keys missing from an element are zero-filled in the stacked result.
a = Batch(a=1, b=2, c=3)
b = Batch(a=4, b=5, d=6)
c = Batch(c=7, b=6, d=9)
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
# test stack with empty Batch()
# An empty Batch() value is zero-filled when other elements hold data for that
# key (d), and a key that is Batch() in every element stays empty (e).
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
# Stacking along axis=-1: keys whose only values are Batch() stay empty, and
# the shared array matches np.stack(..., axis=-1).
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert np.allclose(test.common.c,
np.stack([b1.common.c, b2.common.c], axis=-1))
# NOTE(review): b1 below is set up for a test whose remaining lines are not
# part of this fragment.
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
# Fragment of a Batch test: batch2, batch2_from_comp, Batch and pytest come
# from outside this view. Structurally obvious loop/`with` bodies have been
# re-indented. `np.object` (deprecated in NumPy 1.20, removed in 1.24) is
# replaced by the builtin `object` it aliased; numpy dtypes compare equal to
# it (np.dtype('O') == object).
assert batch2_from_comp.a.d.e == batch2.a.d.e
# Every slicing form that selects the single entry yields the same values.
for batch_slice in [
        batch2[slice(0, 1)], batch2[:1], batch2[0:]]:
    assert batch_slice.a.b == batch2.a.b
    assert batch_slice.a.c == batch2.a.c
    assert batch_slice.a.d.e == batch2.a.d.e
# Arithmetic is applied element-wise to every nested value.
batch2_sum = (batch2 + 1.0) * 2
assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2
assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2
assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2
batch3 = Batch(a={
    'c': np.zeros(1),
    'd': Batch(e=np.array([0.0]), f=np.array([3.0]))})
# Item assignment with a dict or Batch updates only the given keys.
batch3.a.d[0] = {'e': 4.0}
assert batch3.a.d.e[0] == 4.0
batch3.a.d[0] = Batch(f=5.0)
assert batch3.a.d.f[0] == 5.0
# Assigning a key that does not exist in the target raises KeyError.
with pytest.raises(KeyError):
    batch3.a.d[0] = Batch(f=5.0, g=0.0)
# auto convert
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == object  # string arrays are auto-converted to object dtype
batch4.update(a=np.array(['c', 'd']))
assert list(batch4.a) == ['c', 'd']
assert batch4.a.dtype == object  # still object dtype after update()
# Arrays of dicts are auto-converted into nested Batch objects.
batch5 = Batch(a=np.array([{'index': 0}]))
assert isinstance(batch5.a, Batch)
assert np.allclose(batch5.a.index, [0])
batch5.b = np.array([{'index': 1}])
assert isinstance(batch5.b, Batch)
assert np.allclose(batch5.b.index, [1])
def test_batch_pickle():
    """A Batch survives a pickle round trip with nested values intact.

    Covers a nested scalar, a nested torch tensor and a top-level numpy array.
    The pickled bytes are produced locally, so ``pickle.loads`` only sees
    trusted data here.
    """
    original = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
                     np=np.zeros([3, 4]))
    restored = pickle.loads(pickle.dumps(original))
    assert original.obs.a == restored.obs.a
    assert torch.all(original.obs.c == restored.obs.c)
    assert np.all(original.np == restored.np)
# NOTE(review): fragment of a Batch test; b12_stack is defined outside this
# view. The first two lines check that its nested value is a 2-D ndarray.
assert isinstance(b12_stack.a.d.e, np.ndarray)
assert b12_stack.a.d.e.ndim == 2
# test cat with incompatible keys
# A key missing from one part is zero-filled to that part's length: `a` gets
# 4 zero rows for b2 and `b` gets 3 zero rows for b1 (see `ans`).
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with reserved keys (values are Batch())
# An empty Batch() value (b2.a) is zero-filled exactly like a missing key.
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
test = Batch.cat([b1, b2])
ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
b=torch.cat([torch.zeros(3, 3), b2.b]),
common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test cat with all reserved keys (values are Batch())
# NOTE(review): only the setup is visible; the Batch.cat call and assertions
# for this case appear to be truncated from this fragment.
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
b2 = Batch(a=Batch(),
b=torch.rand(4, 3),
common=Batch(c=np.random.rand(4, 5)))
# NOTE(review): fragment of a Batch.stack test; `d` is built outside this view.
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
# test stack with empty Batch()
# An empty Batch() value is zero-filled when other elements hold data for that
# key (d), and a key that is Batch() in every element stays empty (e).
assert Batch.stack([Batch(), Batch(), Batch()]).is_empty()
a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch())
b = Batch(a=4, b=5, d=6, e=Batch())
c = Batch(c=7, b=6, d=9, e=Batch())
d = Batch.stack([a, b, c])
assert np.allclose(d.a, [1, 4, 0])
assert np.allclose(d.b, [2, 5, 6])
assert np.allclose(d.c, [3, 0, 7])
assert np.allclose(d.d, [0, 6, 9])
assert d.e.is_empty()
# Stacking along axis=-1: keys whose only values are Batch() stay empty, and
# the shared array matches np.stack(..., axis=-1).
b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2], axis=-1)
assert test.a.is_empty()
assert test.b.is_empty()
assert np.allclose(test.common.c,
np.stack([b1.common.c, b2.common.c], axis=-1))
# Stacking array-valued incompatible keys: a missing array is replaced by
# zeros of the other element's shape (np.zeros((4, 4)) / torch.zeros(4, 6)).
b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5)))
b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5)))
test = Batch.stack([b1, b2])
ans = Batch(a=np.stack([b1.a, np.zeros((4, 4))]),
b=torch.stack([torch.zeros(4, 6), b2.b]),
common=Batch(c=np.stack([b1.common.c, b2.common.c])))
assert np.allclose(test.a, ans.a)
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)