How to use the xenonpy.datatools.Splitter class in xenonpy

To help you get started, we’ve selected a few xenonpy examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_cv_1(data):
    """Check 5-fold CV over 10 samples when no hold-out test set is requested.

    With ``test_size=0`` a plain ``split()`` must raise ``RuntimeError``.
    ``cv()`` must yield five (train, validation) index pairs of sizes 8 and 2
    whose validation parts, taken together, are a shuffled version of
    ``data[0]``.

    NOTE(review): ``data`` is a fixture defined outside this excerpt;
    ``data[0]`` is presumably the ordered index array -- confirm.
    """
    splitter = Splitter(10, test_size=0, k_fold=5, random_state=123456)
    with pytest.raises(RuntimeError, match='split action is illegal because `test_size` is none'):
        splitter.split()

    held_out = []
    for i, fold in enumerate(splitter.cv()):
        train_idx, val_idx = fold
        assert train_idx.size == 8
        assert val_idx.size == 2
        assert isinstance(train_idx, np.ndarray)
        assert isinstance(val_idx, np.ndarray)
        held_out.append(val_idx)

    # enumerate() starts at 0, so a last index of 4 means exactly five folds.
    assert i == 4

    joined = np.concatenate(held_out)
    # Before sorting, the validation indices must not be in the original order ...
    assert not np.array_equal(joined, data[0])
    # ... but once sorted they must reproduce data[0] exactly.
    ordered = np.sort(joined)
    assert np.array_equal(ordered, data[0])
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_split_2(data):
    """Check that ``split()`` rejects arrays whose first dimension is not 10.

    ``data[1][1:]`` drops one leading entry from ``data[1]``, so splitting it
    with a size-10 splitter must raise ``ValueError``.  Splitting the full
    ``data[1]`` must succeed and the first test element must come from it.

    NOTE(review): ``data`` is a fixture defined outside this excerpt.
    """
    sp = Splitter(10, test_size=0.1)
    # ``match`` is applied with ``re.search``.  The pattern previously had a
    # bare double space where a bracketed argument placeholder appears to have
    # been stripped; ``.*`` accepts the message with or without that placeholder.
    with pytest.raises(ValueError, match=r'parameters .* must have size 10 for dim 0'):
        sp.split(data[1][1:])
    _, test = sp.split(data[1])
    assert test[0] in data[1]
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_roll_1():
    """Check that splits are repeatable and that re-rolling with the same seed keeps them.

    A 10-sample splitter with ``test_size=0.3`` must give 7 training and 3
    test indices.  Calling ``split()`` twice returns identical arrays, and
    ``roll(random_state=123456)`` with the original seed reproduces them again.
    """
    splitter = Splitter(10, test_size=0.3, random_state=123456)

    ref_train, ref_test = splitter.split()
    assert ref_train.size == 7
    assert ref_test.size == 3

    # A second split without rolling must return the very same partition.
    again_train, again_test = splitter.split()
    assert again_train.size == 7
    assert again_test.size == 3
    assert np.array_equal(again_train, ref_train)
    assert np.array_equal(again_test, ref_test)

    # Re-rolling with the same seed must land on the same partition too.
    splitter.roll(random_state=123456)
    rolled_train, rolled_test = splitter.split()
    assert rolled_train.size == 7
    assert rolled_test.size == 3
    assert np.array_equal(rolled_train, ref_train)
    assert np.array_equal(rolled_test, ref_test)
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_split_1(data):
    """Check the default ``Splitter(10)``: no CV, 8/2 train/test split.

    Iterating ``cv()`` on a splitter built without ``k_fold`` must raise
    ``RuntimeError``.  ``split()`` with no arguments returns index arrays of
    sizes 8 and 2; ``split(data[0])`` returns elements that all belong to
    ``data[0]``.

    NOTE(review): ``data`` is a fixture defined outside this excerpt.
    """
    sp = Splitter(10)
    # ``match`` is applied with ``re.search``.  The pattern previously had a
    # bare double space where a bracketed parameter name appears to have been
    # stripped; ``.*`` accepts the message with or without that name.
    with pytest.raises(RuntimeError, match=r'parameter .* must be set'):
        for _ in sp.cv():
            pass

    assert sp.size == 10
    train, test = sp.split()
    assert train.size == 8
    assert test.size == 2

    train, test = sp.split(data[0])
    for d in train:
        assert d in data[0]
    for d in test:
        assert d in data[0]
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_split_4(data):
    """Check that ``split()`` preserves the container type of each input.

    Three inputs are split at once, yielding six outputs (a train/test pair
    per input).  The pairs must be ``np.ndarray``, ``pd.DataFrame`` and
    ``pd.Series`` respectively.

    NOTE(review): ``data`` is a fixture defined outside this excerpt; the
    asserted types suggest ``data[1]``/``data[2]``/``data[3]`` are an array,
    a frame and a series -- confirm.
    """
    splitter = Splitter(10)
    arr_tr, arr_te, frame_tr, frame_te, ser_tr, ser_te = splitter.split(data[1], data[2], data[3])

    expectations = (
        (arr_tr, np.ndarray),
        (arr_te, np.ndarray),
        (frame_tr, pd.DataFrame),
        (frame_te, pd.DataFrame),
        (ser_tr, pd.Series),
        (ser_te, pd.Series),
    )
    for part, expected_type in expectations:
        assert isinstance(part, expected_type)
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_cv_4(data):
    """Check that repeated ``cv()`` passes yield the same validation folds.

    Two consecutive passes over ``cv()`` on a seeded splitter must produce
    identical concatenated validation indices.  A third pass is made after
    ``roll()``.
    """
    splitter = Splitter(10, test_size=0, k_fold=5, random_state=123456)

    first_pass = np.concatenate([val_idx for _, val_idx in splitter.cv()])
    second_pass = np.concatenate([val_idx for _, val_idx in splitter.cv()])
    assert np.array_equal(first_pass, second_pass)

    splitter.roll()
    # NOTE(review): the post-roll folds are collected but nothing is asserted
    # on them in this excerpt -- a final check may have been cut off.
    rolled_pass = np.concatenate([val_idx for _, val_idx in splitter.cv()])
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_cv_2(data):
    """Check 4-fold CV combined with a 20% hold-out test set on 10 samples.

    Each ``cv()`` step must yield three ``np.ndarray`` objects of sizes 6, 2
    and 2.  The third one must be identical on every fold, and the second
    ones must add up to 8 entries over the four folds.
    """
    splitter = Splitter(10, test_size=0.2, k_fold=4)

    val_parts = []
    test_parts = []
    for train_idx, val_idx, test_idx in splitter.cv():
        assert train_idx.size == 6
        assert val_idx.size == 2
        assert test_idx.size == 2
        assert isinstance(train_idx, np.ndarray)
        assert isinstance(val_idx, np.ndarray)
        assert isinstance(test_idx, np.ndarray)
        test_parts.append(test_idx)
        val_parts.append(val_idx)

    # The third yielded array must not change from fold to fold.
    for fold_no in (1, 2, 3):
        assert np.array_equal(test_parts[0], test_parts[fold_no])

    all_val = np.concatenate(val_parts)
    assert all_val.size == 8
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_init_1():
    """Check that ``test_size=0`` together with ``k_fold=None`` is rejected.

    Constructing such a splitter must raise ``RuntimeError``.
    """
    # ``match`` is applied with ``re.search``.  The pattern previously began
    # with a stray space and had a bare double space where bracketed parameter
    # names appear to have been stripped; the search now starts at the first
    # stable word and ``.*`` accepts the message with or without that name.
    with pytest.raises(RuntimeError, match=r'can be none only if .* is not none'):
        Splitter(10, test_size=0, k_fold=None)
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_cv_5(data):
    """Check that ``cv()`` splits several containers at once and keeps their types.

    For every fold of a 5-fold CV, the validation parts of the three inputs
    must be an ``np.ndarray``, a ``pd.DataFrame`` and a ``pd.Series`` with
    ``size`` 20, 20 and 2 respectively.

    NOTE(review): ``data`` is a fixture defined outside this excerpt.
    """
    splitter = Splitter(10, test_size=0, k_fold=5)
    for fold in splitter.cv(data[1], data[2], data[3]):
        # Each fold is (train, val) for every input; only the val parts are checked.
        _, arr_val, _, frame_val, _, ser_val = fold
        assert isinstance(arr_val, np.ndarray)
        assert isinstance(frame_val, pd.DataFrame)
        assert isinstance(ser_val, pd.Series)
        assert arr_val.size == 20
        assert frame_val.size == 20
        assert ser_val.size == 2
github yoshida-lab / XenonPy / tests / datatools / test_data_select.py View on Github external
def test_cv_3(data):
    """Check CV when ``k_fold`` is taken from the fixture instead of being an int.

    Every fold must partition all 10 samples between train and validation.
    The validation folds must have sizes 2, 4 and 4 (in any order) and,
    together, reproduce ``data[0]`` after sorting.

    NOTE(review): ``data`` is a fixture defined outside this excerpt;
    ``data[4]`` presumably holds per-sample group labels -- confirm.
    """
    np.random.seed(123456)
    splitter = Splitter(10, test_size=0, k_fold=data[4])

    val_folds = []
    for train_idx, val_idx in splitter.cv():
        assert isinstance(train_idx, np.ndarray)
        assert isinstance(val_idx, np.ndarray)
        assert train_idx.size + val_idx.size == 10
        val_folds.append(val_idx)

    fold_sizes = sorted(fold.size for fold in val_folds)
    assert np.array_equal(fold_sizes, [2, 4, 4])

    covered = np.sort(np.concatenate(val_folds))
    assert np.array_equal(covered, data[0])