How to use the `tianshou.data.Batch.cat` function in Tianshou

To help you get started, we've selected a few Tianshou examples based on popular ways `Batch.cat` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

Source: thu-ml/tianshou, `test/base/test_batch.py` (view on GitHub)
def test_batch_cat_and_stack():
    """Check ``Batch.cat`` and in-place ``Batch.cat_`` on nested batches.

    Excerpt only: the listing stops right after the "lens infer" call at the
    bottom, so the assertions on ``ans`` are not shown here.
    """
    # test cat with compatible keys
    # Both batches hold the same nested keys (``a.b`` and ``a.d.e``); ``b2``
    # gives ``d`` as a plain dict instead of a ``Batch``.
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    # Work on a deep copy so the in-place ``cat_`` leaves ``b1`` untouched.
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # Out-of-place and in-place concatenation must agree.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    # NOTE(review): this repeats the assertion above verbatim; it was
    # presumably meant to compare another field (e.g. ``a.b``) — confirm.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    # The 0-d ``e`` leaves of the two inputs are expected to become one
    # 1-D array.
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # The middle batch only has an empty nested ``Batch``; the assertion
    # expects it to contribute no rows to ``a.a``.
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    # ``b.a.a`` is empty and ``b.a.t`` exists only in ``b``; ``b.b`` has 3 rows.
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
Source: thu-ml/tianshou, `test/base/test_batch.py` (view on GitHub)
def test_batch_cat_and_stack():
    """Check ``Batch.cat``, in-place ``Batch.cat_`` and ``Batch.stack``.

    Excerpt only: the listing stops right after ``b1`` is rebuilt for the
    "incompatible keys" case at the bottom.
    """
    # test cat with compatible keys
    # Both batches hold the same nested keys (``a.b`` and ``a.d.e``); ``b2``
    # gives ``d`` as a plain dict instead of a ``Batch``.
    b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}])
    b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}])
    b12_cat_out = Batch.cat([b1, b2])
    # Work on a deep copy so the in-place ``cat_`` leaves ``b1`` untouched
    # (``b1`` is reused by ``Batch.stack`` further down).
    b12_cat_in = copy.deepcopy(b1)
    b12_cat_in.cat_(b2)
    # Out-of-place and in-place concatenation must agree.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    # NOTE(review): this repeats the assertion above verbatim; it was
    # presumably meant to compare another field (e.g. ``a.b``) — confirm.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    # The 0-d ``e`` leaves of the two inputs are expected to become one
    # 1-D array.
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # The middle batch only has an empty nested ``Batch``; the assertion
    # expects it to contribute no rows to ``a.a``.
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    # ``b.a.a`` is empty and ``b.a.t`` exists only in ``b``; ``b.b`` has 3 rows.
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    # The empty ``b.a.a`` is expected to be zero-filled with 3 rows (the
    # length of ``b.b``), while ``t`` -- empty in the only batch that has
    # it -- stays empty.
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # Stacking the compatible-key ``b1``/``b2`` from the top is expected to
    # yield a 2-D ``a.d.e`` (vs. 1-D after ``cat``).
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2

    # test cat with incompatible keys
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
Source: thu-ml/tianshou, `test/base/test_batch.py` (view on GitHub)
b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    # Excerpt starts mid-function; the indentation of the line above was lost
    # in the listing. ``b1`` has key ``a`` only, ``b2`` has ``b`` only, and
    # both share ``common.c``.
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    # Expected: a key missing from one batch is zero-padded over that batch's
    # rows (4 rows of ``a`` for ``b2``, 3 rows of ``b`` for ``b1``), with
    # NumPy padding for ``a`` and torch padding for ``b``.
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with reserved keys (values are Batch())
    # ``b2.a`` is an empty ``Batch``; the expected result pads it with zeros
    # exactly as if ``a`` were missing from ``b2``.
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    # ``a`` is an empty ``Batch`` in both inputs; only ``b2`` has ``b``.
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    # Expected: ``a`` stays empty and ``b`` is zero-padded over ``b1``'s
    # 3 rows. (The listing is cut off here, before this case's assertions.)
    ans = Batch(a=Batch(),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
Source: thu-ml/tianshou, `test/base/test_batch.py` (view on GitHub)
b12_cat_in = copy.deepcopy(b1)
    # Excerpt starts mid-function; the indentation of the line above was lost
    # in the listing. ``b1``, ``b2`` and ``b12_cat_out`` come from the part
    # that is not shown.
    b12_cat_in.cat_(b2)
    # Out-of-place and in-place concatenation must agree.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    # NOTE(review): this repeats the assertion above verbatim; it was
    # presumably meant to compare another field — confirm.
    assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e)
    assert isinstance(b12_cat_in.a.d.e, np.ndarray)
    assert b12_cat_in.a.d.e.ndim == 1

    # The middle batch only has an empty nested ``Batch``; the assertion
    # expects it to contribute no rows to ``a.a``.
    a = Batch(a=Batch(a=np.random.randn(3, 4)))
    assert np.allclose(
        np.concatenate([a.a.a, a.a.a]),
        Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a)

    # test cat with lens infer
    # ``b.a.a`` is empty and ``b.a.t`` exists only in ``b``; ``b.b`` has 3 rows.
    a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4))
    b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4))
    ans = Batch.cat([a, b, a])
    # The empty ``b.a.a`` is expected to be zero-filled with 3 rows (the
    # length of ``b.b``), while ``t`` -- empty in the only batch that has
    # it -- stays empty.
    assert np.allclose(ans.a.a,
                       np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a]))
    assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b]))
    assert ans.a.t.is_empty()

    # Stacking ``b1``/``b2`` is expected to yield a 2-D ``a.d.e``.
    b12_stack = Batch.stack((b1, b2))
    assert isinstance(b12_stack.a.d.e, np.ndarray)
    assert b12_stack.a.d.e.ndim == 2

    # test cat with incompatible keys
    # ``b1`` has key ``a`` only, ``b2`` has ``b`` only; both share ``common.c``.
    b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    # Expected: a key missing from one batch is zero-padded over that batch's
    # rows (4 rows of ``a`` for ``b2``, 3 rows of ``b`` for ``b1``). The
    # listing is cut off here, before the assertions on ``test``.
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
Source: thu-ml/tianshou, `test/base/test_batch.py` (view on GitHub)
b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    # Excerpt starts in the middle of a ``Batch(...)`` call whose beginning
    # (and the preceding setup of ``b1``/``b2``) is not shown; the
    # indentation of the first line was lost in the listing.
    test = Batch.cat([b1, b2])
    # Expected: missing keys are zero-padded over the other batch's rows.
    ans = Batch(a=np.concatenate([b1.a, np.zeros((4, 4))]),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    assert np.allclose(test.a, ans.a)
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test cat with all reserved keys (values are Batch())
    # ``a`` is an empty ``Batch`` in both inputs; only ``b2`` has ``b``.
    b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5)))
    b2 = Batch(a=Batch(),
               b=torch.rand(4, 3),
               common=Batch(c=np.random.rand(4, 5)))
    test = Batch.cat([b1, b2])
    # Expected: ``a`` stays empty and ``b`` is zero-padded over ``b1``'s
    # 3 rows.
    ans = Batch(a=Batch(),
                b=torch.cat([torch.zeros(3, 3), b2.b]),
                common=Batch(c=np.concatenate([b1.common.c, b2.common.c])))
    # NOTE(review): this inspects ``ans`` (built with ``a=Batch()`` just
    # above, so presumably trivially empty) rather than the actual result;
    # ``test.a.is_empty()`` was probably intended — confirm.
    assert ans.a.is_empty()
    assert torch.allclose(test.b, ans.b)
    assert np.allclose(test.common.c, ans.common.c)

    # test stack with compatible keys
    # ``b3`` and ``b4`` share the same keys and shapes; they are stacked
    # along axis 1 and compared against ``np.stack`` on the same axis.
    b3 = Batch(a=np.zeros((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[1], [2]]))
    b4 = Batch(a=np.ones((3, 4)),
               b=torch.ones((2, 5)),
               c=Batch(d=[[0], [3]]))
    b34_stack = Batch.stack((b3, b4), axis=1)
    assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1))