How to use the `ampligraph.evaluation.generate_corruptions_for_eval` function in AmpliGraph

To help you get started, we’ve selected a few AmpliGraph examples based on popular ways the function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

Source: Accenture/AmpliGraph — tests/ampligraph/evaluation/test_protocol.py (view on GitHub)
def test_generate_corruptions_for_eval_filtered():
    """Check filtered corruption generation for the single triple (0, 0, 1).

    Corruptions are drawn from the entities 0-3. The expected output contains
    neither of the ``filter_triples`` rows nor the test triple itself.
    Row order is not part of what this test pins, so both arrays are compared
    after sorting their rows lexicographically.
    """
    x = np.array([0, 0, 1])
    idx_entities = np.array([0, 1, 2, 3])
    filter_triples = np.array(([1, 0, 1], [2, 0, 1]))

    x_n_actual = generate_corruptions_for_eval(x, idx_entities=idx_entities, filter=filter_triples)
    x_n_expected = np.array([[3, 0, 1],
                             [0, 0, 0],
                             [0, 0, 2],
                             [0, 0, 3]])

    def _sort_rows(triples):
        # Sort whole rows lexicographically. np.sort(..., axis=0) would sort
        # each column independently, breaking the (s, p, o) row structure and
        # letting an unrelated set of triples compare equal. np.lexsort uses
        # its LAST key as the primary key, hence the reversed column order.
        arr = np.asarray(triples)
        return arr[np.lexsort(arr.T[::-1])]

    np.testing.assert_array_equal(_sort_rows(x_n_actual), _sort_rows(x_n_expected))
Source: Accenture/AmpliGraph — tests/ampligraph/evaluation/test_protocol.py (view on GitHub)
def test_generate_corruptions_for_eval():
    """Check unfiltered corruptions of the first triple of a small graph.

    The graph has eight distinct entities, and the expected array uses entity
    indices 0-7. Per the expected array, the first triple ('a', 'x', 'b') is
    assumed to map to (0, 0, 1) — this relies on ``create_mappings`` assigning
    'a' -> 0 and 'b' -> 1. The expected output lists the object-side
    corruptions (0, 0, e) first and then the subject-side corruptions
    (e, 0, 1), each over every entity. The test triple therefore appears once
    in each half.
    """
    X = np.array([['a', 'x', 'b'],
                  ['c', 'x', 'd'],
                  ['e', 'x', 'f'],
                  ['b', 'y', 'h'],
                  ['a', 'y', 'l']])

    rel_to_idx, ent_to_idx = create_mappings(X)
    X = to_idx(X, ent_to_idx=ent_to_idx, rel_to_idx=rel_to_idx)

    with tf.Session() as sess:
        all_ent = tf.constant(list(ent_to_idx.values()), dtype=tf.int64)
        # Keep the batch dimension: a (1, 3) tensor holding only the first triple.
        x = tf.constant(np.array([X[0]]), dtype=tf.int64)
        x_n_actual = sess.run(generate_corruptions_for_eval(x, all_ent))
        x_n_expected = np.array([[0, 0, 0],
                                 [0, 0, 1],
                                 [0, 0, 2],
                                 [0, 0, 3],
                                 [0, 0, 4],
                                 [0, 0, 5],
                                 [0, 0, 6],
                                 [0, 0, 7],
                                 [0, 0, 1],
                                 [1, 0, 1],
                                 [2, 0, 1],
                                 [3, 0, 1],
                                 [4, 0, 1],
                                 [5, 0, 1],
                                 [6, 0, 1],
                                 [7, 0, 1]])

    # Without this assertion the test could never fail. Compare exactly
    # rather than order-insensitively: the object-then-subject ordering
    # laid out in x_n_expected is part of what is being pinned.
    np.testing.assert_array_equal(x_n_actual, x_n_expected)
Source: Accenture/AmpliGraph — ampligraph/latent_features/models/EmbeddingModel.py (view on GitHub)
if corruption_entities == 'all':
                corruption_entities = all_entities_np
            elif isinstance(corruption_entities, np.ndarray):
                corruption_entities = corruption_entities
            else:
                msg = 'Invalid type for corruption entities.'
                logger.error(msg)
                raise ValueError(msg)

            # Entities that must be used while generating corruptions
            self.corruption_entities_tf = tf.constant(corruption_entities, dtype=tf.int32)

            corrupt_side = self.eval_config.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_EVAL)
            # Generate corruptions
            self.out_corr = generate_corruptions_for_eval(self.X_test_tf,
                                                          self.corruption_entities_tf,
                                                          corrupt_side)

            # Compute scores for negatives
            e_s, e_p, e_o = self._lookup_embeddings(self.out_corr)
            self.scores_predict = self._fn(e_s, e_p, e_o)

            # Compute scores for positive
            e_s, e_p, e_o = self._lookup_embeddings(self.X_test_tf)
            self.score_positive = tf.squeeze(self._fn(e_s, e_p, e_o))

            use_default_protocol = self.eval_config.get('default_protocol', constants.DEFAULT_PROTOCOL_EVAL)

            if use_default_protocol:
                obj_corruption_scores = tf.slice(self.scores_predict,
                                                 [0],
Source: Accenture/AmpliGraph — ampligraph/latent_features/models.py (view on GitHub)
if corruption_entities == 'all':
                corruption_entities = all_entities_np
            elif isinstance(corruption_entities, np.ndarray):
                corruption_entities = corruption_entities
            else:
                msg = 'Invalid type for corruption entities.'
                logger.error(msg)
                raise ValueError(msg)

            # Entities that must be used while generating corruptions
            self.corruption_entities_tf = tf.constant(corruption_entities, dtype=tf.int32)

            corrupt_side = self.eval_config.get('corrupt_side', DEFAULT_CORRUPT_SIDE_EVAL)
            # Generate corruptions
            self.out_corr = generate_corruptions_for_eval(self.X_test_tf,
                                                          self.corruption_entities_tf,
                                                          corrupt_side)

            # Compute scores for negatives
            e_s, e_p, e_o = self._lookup_embeddings(self.out_corr)
            self.scores_predict = self._fn(e_s, e_p, e_o)

            # Compute scores for positive
            e_s, e_p, e_o = self._lookup_embeddings(self.X_test_tf)
            self.score_positive = tf.squeeze(self._fn(e_s, e_p, e_o))

            use_default_protocol = self.eval_config.get('default_protocol', DEFAULT_PROTOCOL_EVAL)

            if use_default_protocol:
                obj_corruption_scores = tf.slice(self.scores_predict,
                                                 [0],