Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_generate_corruptions_for_eval_filtered():
    """Filtered eval corruptions must drop every triple listed in ``filter``.

    ``x`` is corrupted against ``idx_entities`` (0..3). Per the expected array,
    the surviving corruptions are those that are neither in ``filter_triples``
    nor equal to ``x`` itself: ``[3, 0, 1]`` (first column replaced) and
    ``[0, 0, 0]``, ``[0, 0, 2]``, ``[0, 0, 3]`` (third column replaced).
    Row order is not part of what this test checks.
    """

    def _sort_rows(triples):
        # Sort whole rows lexicographically (first column, then second, then
        # third). Unlike np.sort(..., axis=0), which sorts each column on its
        # own and can accept arrays whose triples were scrambled across rows,
        # this ignores row order while keeping every triple intact.
        triples = np.asarray(triples)
        return triples[np.lexsort(triples.T[::-1])]

    x = np.array([0, 0, 1])
    idx_entities = np.array([0, 1, 2, 3])
    filter_triples = np.array(([1, 0, 1], [2, 0, 1]))
    x_n_actual = generate_corruptions_for_eval(
        x, idx_entities=idx_entities, filter=filter_triples
    )
    x_n_expected = np.array([[3, 0, 1],
                             [0, 0, 0],
                             [0, 0, 2],
                             [0, 0, 3]])
    np.testing.assert_array_equal(_sort_rows(x_n_actual),
                                  _sort_rows(x_n_expected))
def test_generate_corruptions_for_eval():
    """Unfiltered eval corruptions cover every entity on both sides.

    ``X`` has 8 distinct entities, so corrupting the single test triple
    ``X[0]`` yields 16 rows: 8 with the third column replaced and 8 with the
    first column replaced. Both halves include the original triple, which is
    why ``[0, 0, 1]`` appears twice in the expected array.

    NOTE(review): the expected values assume ``create_mappings`` maps
    ``X[0]`` (``['a', 'x', 'b']``) to ``[0, 0, 1]`` -- confirm.
    """

    def _sort_rows(triples):
        # Sort whole rows lexicographically so the comparison does not depend
        # on the order in which the two corruption sides are concatenated,
        # while each triple stays intact (np.sort(axis=0) would not).
        triples = np.asarray(triples)
        return triples[np.lexsort(triples.T[::-1])]

    X = np.array([['a', 'x', 'b'],
                  ['c', 'x', 'd'],
                  ['e', 'x', 'f'],
                  ['b', 'y', 'h'],
                  ['a', 'y', 'l']])
    rel_to_idx, ent_to_idx = create_mappings(X)
    X = to_idx(X, ent_to_idx=ent_to_idx, rel_to_idx=rel_to_idx)

    with tf.Session() as sess:
        all_ent = tf.constant(list(ent_to_idx.values()), dtype=tf.int64)
        x = tf.constant(np.array([X[0]]), dtype=tf.int64)
        x_n_actual = sess.run(generate_corruptions_for_eval(x, all_ent))

    x_n_expected = np.array([[0, 0, 0],
                             [0, 0, 1],
                             [0, 0, 2],
                             [0, 0, 3],
                             [0, 0, 4],
                             [0, 0, 5],
                             [0, 0, 6],
                             [0, 0, 7],
                             [0, 0, 1],
                             [1, 0, 1],
                             [2, 0, 1],
                             [3, 0, 1],
                             [4, 0, 1],
                             [5, 0, 1],
                             [6, 0, 1],
                             [7, 0, 1]])
    # The original test computed both arrays but never compared them, so it
    # could not fail on wrong output.
    np.testing.assert_array_equal(_sort_rows(x_n_actual),
                                  _sort_rows(x_n_expected))
# NOTE(review): fragment of a method body (uses `self`); the enclosing `def`
# and the remainder of the tf.slice(...) call at the end are not visible here.
# Resolve which entities to use for corruptions: the string 'all' selects
# all_entities_np, an np.ndarray is used as given, anything else is rejected
# with ValueError (after logging the message).
# NOTE(review): `corruption_entities == 'all'` is evaluated before the ndarray
# check. With recent NumPy (>= 1.25) comparing a numeric array to a string
# returns an elementwise boolean array, and `if` on a multi-element array
# raises ValueError. Testing isinstance(corruption_entities, str) first would
# avoid that -- confirm against the pinned NumPy version.
if corruption_entities == 'all':
corruption_entities = all_entities_np
elif isinstance(corruption_entities, np.ndarray):
# No-op self-assignment; this branch only exists to accept ndarrays.
corruption_entities = corruption_entities
else:
msg = 'Invalid type for corruption entities.'
logger.error(msg)
raise ValueError(msg)
# Entities that must be used while generating corruptions
self.corruption_entities_tf = tf.constant(corruption_entities, dtype=tf.int32)
# Side(s) of each test triple to corrupt; falls back to
# constants.DEFAULT_CORRUPT_SIDE_EVAL when eval_config has no 'corrupt_side'.
corrupt_side = self.eval_config.get('corrupt_side', constants.DEFAULT_CORRUPT_SIDE_EVAL)
# Generate corruptions
self.out_corr = generate_corruptions_for_eval(self.X_test_tf,
self.corruption_entities_tf,
corrupt_side)
# Compute scores for negatives
e_s, e_p, e_o = self._lookup_embeddings(self.out_corr)
self.scores_predict = self._fn(e_s, e_p, e_o)
# Compute scores for positive; tf.squeeze drops size-1 dimensions.
e_s, e_p, e_o = self._lookup_embeddings(self.X_test_tf)
self.score_positive = tf.squeeze(self._fn(e_s, e_p, e_o))
# Whether to use the default evaluation protocol; falls back to
# constants.DEFAULT_PROTOCOL_EVAL. The branch body is truncated in this view.
use_default_protocol = self.eval_config.get('default_protocol', constants.DEFAULT_PROTOCOL_EVAL)
if use_default_protocol:
obj_corruption_scores = tf.slice(self.scores_predict,
[0],
# NOTE(review): fragment of a method body (uses `self`) that repeats the
# corruption-entity / eval-graph logic seen earlier in this file, but with
# bare DEFAULT_CORRUPT_SIDE_EVAL / DEFAULT_PROTOCOL_EVAL names (presumably
# imported at module level -- verify). The enclosing `def` and the rest of
# the tf.slice(...) call at the end are not visible here.
# Resolve which entities to use for corruptions: the string 'all' selects
# all_entities_np, an np.ndarray is used as given, anything else is rejected
# with ValueError (after logging the message).
# NOTE(review): `corruption_entities == 'all'` is evaluated before the ndarray
# check. With recent NumPy (>= 1.25) comparing a numeric array to a string
# returns an elementwise boolean array, and `if` on a multi-element array
# raises ValueError. Testing isinstance(corruption_entities, str) first would
# avoid that -- confirm against the pinned NumPy version.
if corruption_entities == 'all':
corruption_entities = all_entities_np
elif isinstance(corruption_entities, np.ndarray):
# No-op self-assignment; this branch only exists to accept ndarrays.
corruption_entities = corruption_entities
else:
msg = 'Invalid type for corruption entities.'
logger.error(msg)
raise ValueError(msg)
# Entities that must be used while generating corruptions
self.corruption_entities_tf = tf.constant(corruption_entities, dtype=tf.int32)
# Side(s) of each test triple to corrupt; falls back to
# DEFAULT_CORRUPT_SIDE_EVAL when eval_config has no 'corrupt_side'.
corrupt_side = self.eval_config.get('corrupt_side', DEFAULT_CORRUPT_SIDE_EVAL)
# Generate corruptions
self.out_corr = generate_corruptions_for_eval(self.X_test_tf,
self.corruption_entities_tf,
corrupt_side)
# Compute scores for negatives
e_s, e_p, e_o = self._lookup_embeddings(self.out_corr)
self.scores_predict = self._fn(e_s, e_p, e_o)
# Compute scores for positive; tf.squeeze drops size-1 dimensions.
e_s, e_p, e_o = self._lookup_embeddings(self.X_test_tf)
self.score_positive = tf.squeeze(self._fn(e_s, e_p, e_o))
# Whether to use the default evaluation protocol; falls back to
# DEFAULT_PROTOCOL_EVAL. The branch body runs past the end of this view.
use_default_protocol = self.eval_config.get('default_protocol', DEFAULT_PROTOCOL_EVAL)
if use_default_protocol:
obj_corruption_scores = tf.slice(self.scores_predict,
[0],