How to use the `ivis.data.triplet_generators.KnnTripletGenerator` class in ivis

To help you get started, we’ve selected a few ivis examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: beringresearch/ivis — tests/data/test_triplet_generator.py (view on GitHub)
def test_AnnoyTripletGenerator():
    """Check the shape of batches produced by a triplet generator on iris.

    NOTE(review): despite its name, this test constructs a
    ``KnnTripletGenerator`` (not an ``AnnoyTripletGenerator``) and performs
    the same checks as ``test_KnnTripletGenerator``. It presumably was meant
    to exercise ``AnnoyTripletGenerator`` against a loaded Annoy index --
    TODO confirm intent. The name is kept so pytest node IDs stay stable.
    """
    # Precomputed neighbour list; presumably k=4 neighbours per row, going by
    # the file name -- verify against how the fixture was generated.
    neighbour_list = np.load('tests/data/test_knn_k4.npy')

    iris = datasets.load_iris()
    X = iris.data
    batch_size = 32

    data_generator = KnnTripletGenerator(X, neighbour_list,
                                         batch_size=batch_size)

    # Request batch indices 0 .. X.shape[0] // batch_size. When the row count
    # is not a multiple of batch_size, the last index is the final (partial)
    # batch, so this is exactly one pass over the data; only when the row
    # count divides evenly does the last index reach one batch past the end.
    n_batches = (X.shape[0] // batch_size) + 1
    for i in range(n_batches):
        batch = data_generator[i]

        # Each batch is a 2-tuple; by Keras convention presumably
        # (inputs, targets).
        assert isinstance(batch, tuple)
        assert len(batch) == 2

        inputs, targets = batch
        # Three inputs -- presumably anchor, positive and negative rows.
        assert len(inputs) == 3
        # The final batch may hold fewer than batch_size rows.
        assert len(targets) <= batch_size
        # The first input keeps the original feature dimensionality.
        assert inputs[0].shape[-1] == X.shape[-1]
GitHub: beringresearch/ivis — tests/data/test_triplet_generator.py (view on GitHub)
def test_KnnTripletGenerator():
    """Check the shape of batches produced by ``KnnTripletGenerator`` on iris.

    Builds the generator from a precomputed neighbour list, walks the batch
    indices and asserts each batch is a 2-tuple whose first element holds
    three arrays with the dataset's feature dimensionality and whose second
    element holds at most ``batch_size`` entries.
    """
    # Precomputed neighbour list; presumably k=4 neighbours per row, going by
    # the file name -- verify against how the fixture was generated.
    neighbour_list = np.load('tests/data/test_knn_k4.npy')

    iris = datasets.load_iris()
    X = iris.data
    batch_size = 32

    data_generator = KnnTripletGenerator(X, neighbour_list,
                                         batch_size=batch_size)

    # Request batch indices 0 .. X.shape[0] // batch_size. When the row count
    # is not a multiple of batch_size, the last index is the final (partial)
    # batch, so this is exactly one pass over the data; only when the row
    # count divides evenly does the last index reach one batch past the end.
    n_batches = (X.shape[0] // batch_size) + 1
    for i in range(n_batches):
        batch = data_generator[i]

        # Each batch is a 2-tuple; by Keras convention presumably
        # (inputs, targets).
        assert isinstance(batch, tuple)
        assert len(batch) == 2

        inputs, targets = batch
        # Three inputs -- presumably anchor, positive and negative rows.
        assert len(inputs) == 3
        # The final batch may hold fewer than batch_size rows.
        assert len(targets) <= batch_size
        # The first input keeps the original feature dimensionality.
        assert inputs[0].shape[-1] == X.shape[-1]
GitHub: beringresearch/ivis — ivis/data/triplet_generators.py (view on GitHub)
raise Exception('''k value greater than or equal to (num_rows - 1)
                        (k={}, rows={}). Lower k to a smaller
                        value.'''.format(k, X.shape[0]))
    if batch_size > X.shape[0]:
        raise Exception('''batch_size value larger than num_rows in dataset
                        (batch_size={}, rows={}). Lower batch_size to a
                        smaller value.'''.format(batch_size, X.shape[0]))

    if Y is None:
        if precompute:
            if verbose > 0:
                print('Extracting KNN from index')

            neighbour_matrix = extract_knn(X, index_path, k=k,
                                           search_k=search_k, verbose=verbose)
            return KnnTripletGenerator(X, neighbour_matrix,
                                       batch_size=batch_size)
        else:
            index = AnnoyIndex(X.shape[1], metric='angular')
            index.load(index_path)
            return AnnoyTripletGenerator(X, index, k=k,
                                         batch_size=batch_size,
                                         search_k=search_k)
    else:
        if precompute:
            if verbose > 0:
                print('Extracting KNN from index')

            neighbour_matrix = extract_knn(X, index_path, k=k,
                                           search_k=search_k, verbose=verbose)
            return LabeledKnnTripletGenerator(X, Y, neighbour_matrix,
                                              batch_size=batch_size)