How to use the ivis.nn.losses.triplet_loss function in ivis

To help you get started, we’ve selected a few ivis examples that show popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github beringresearch / ivis / tests / nn / test_losses.py View on Github external
def test_loss_function_call():
    """Verify that triplet_loss returns a function named after its distance.

    Every key of the mapping returned by ``get_loss_functions`` is passed to
    ``triplet_loss`` as ``distance``; the resulting callable's ``__name__``
    must equal that key.
    """
    margin = 2

    # Iterating the mapping directly yields its keys (the loss names).
    for loss_name in get_loss_functions(margin=margin):
        built_loss = triplet_loss(distance=loss_name, margin=margin)
        assert built_loss.__name__ == loss_name
github beringresearch / ivis / ivis / ivis.py View on Github external
build_annoy_index(X, self.annoy_index_path,
                              ntrees=self.ntrees,
                              build_index_on_disk=self.build_index_on_disk,
                              verbose=self.verbose)

        datagen = generator_from_index(X, Y,
                                       index_path=self.annoy_index_path,
                                       k=self.k,
                                       batch_size=self.batch_size,
                                       search_k=self.search_k,
                                       precompute=self.precompute,
                                       verbose=self.verbose)

        loss_monitor = 'loss'
        try:
            triplet_loss_func = triplet_loss(distance=self.distance,
                                             margin=self.margin)
        except KeyError:
            raise ValueError('Loss function `{}` not implemented.'.format(self.distance))

        if self.model_ is None:
            if type(self.model_def) is str:
                input_size = (X.shape[-1],)
                self.model_, anchor_embedding, _, _ = \
                    triplet_network(base_network(self.model_def, input_size),
                                    embedding_dims=self.embedding_dims)
            else:
                self.model_, anchor_embedding, _, _ = \
                    triplet_network(self.model_def,
                                    embedding_dims=self.embedding_dims)

            if Y is None:
github beringresearch / ivis / ivis / ivis.py View on Github external
Parameters
        ----------
        folder_path : string
            Path to serialised model files and metadata

        Returns
        -------
        returns an ivis instance
        """

        ivis_config = json.load(open(os.path.join(folder_path,
                                                  'ivis_params.json'), 'r'))
        self.__dict__ = ivis_config

        loss_function = triplet_loss(self.distance, self.margin)
        self.model_ = load_model(os.path.join(folder_path, 'ivis_model.h5'),
                                 custom_objects={'tf': tf,
                                                 loss_function.__name__: loss_function })
        self.encoder = self.model_.layers[3]
        self.encoder._make_predict_function()

        # If a supervised model exists, load it
        supervised_path = os.path.join(folder_path, 'supervised_model.h5')
        if os.path.exists(supervised_path):
            self.supervised_model_ = load_model(supervised_path)
        return self