How to use the ivis.data.triplet_generators.generator_from_index function in ivis

To help you get started, we’ve selected a few ivis examples based on popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: beringresearch/ivis on GitHub, file tests/data/test_triplet_generator.py (view on GitHub)
def test_generator_from_index():
        """Check that generator_from_index raises for oversized k or batch_size.

        Both cases pass a 4x5 all-zero array and a placeholder index path;
        only ``k`` and ``batch_size`` differ between them.
        """
        def build_generator(k, batch_size):
                # A fresh input array is created on every call; the remaining
                # arguments are the same for both cases.
                return generator_from_index(np.zeros(shape=(4, 5)),
                                            'placeholder_path.index',
                                            k=k,
                                            batch_size=batch_size,
                                            search_k=1,
                                            precompute=False,
                                            verbose=0)

        # Test too large k raises exception
        with pytest.raises(Exception):
                build_generator(k=10, batch_size=2)

        # Test too large batch_size raises exception
        with pytest.raises(Exception):
                build_generator(k=2, batch_size=8)
Source: beringresearch/ivis on GitHub, file ivis/ivis.py (view on GitHub)
def _fit(self, X, Y=None, shuffle_mode=True):

        if self.annoy_index_path is None:
            self.annoy_index_path = 'annoy.index'
            if self.verbose > 0:
                print('Building KNN index')
            build_annoy_index(X, self.annoy_index_path,
                              ntrees=self.ntrees,
                              build_index_on_disk=self.build_index_on_disk,
                              verbose=self.verbose)

        datagen = generator_from_index(X, Y,
                                       index_path=self.annoy_index_path,
                                       k=self.k,
                                       batch_size=self.batch_size,
                                       search_k=self.search_k,
                                       precompute=self.precompute,
                                       verbose=self.verbose)

        loss_monitor = 'loss'
        try:
            triplet_loss_func = triplet_loss(distance=self.distance,
                                             margin=self.margin)
        except KeyError:
            raise ValueError('Loss function `{}` not implemented.'.format(self.distance))

        if self.model_ is None:
            if type(self.model_def) is str: