from_idx : bool
If True, will skip conversion to internal IDs (default: False).
Returns
-------
scores_predict : ndarray, shape [n]
The predicted scores for input triples X.
"""
if not self.is_fitted:
msg = 'Model has not been fitted.'
logger.error(msg)
raise RuntimeError(msg)
# adapt the data with numpy adapter for internal use
dataset_handle = NumpyDatasetAdapter()
dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)
dataset_handle.set_data(X, "test", mapped_status=from_idx)
self.eval_dataset_handle = dataset_handle
# build tf graph for predictions
if self.sess_predict is None:
tf.reset_default_graph()
self.rnd = check_random_state(self.seed)
tf.random.set_random_seed(self.seed)
# load the parameters
self._load_model_from_trained_params()
# build the eval graph
self._initialize_eval_graph()
sess = tf.Session()
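The ``if self.sess_predict is None`` guard above suggests the evaluation graph and session are built only on the first call and reused afterwards. For orientation, here is a minimal, hypothetical end-to-end sketch of this prediction path; it assumes the AmpliGraph 1.x ``TransE`` constructor and the ``fit``/``predict`` signatures documented above, and the toy triples are invented purely for illustration.
# Hypothetical usage sketch (not library source): train a tiny model and score triples.
import numpy as np
from ampligraph.latent_features import TransE

X_train = np.array([['a', 'likes', 'b'],
                    ['b', 'likes', 'c'],
                    ['a', 'knows', 'c']])

model = TransE(k=10, epochs=5, batches_count=1, seed=0, verbose=False)
model.fit(X_train)

# from_idx is left at its default (False), so predict() first maps the string
# labels to internal entity/relation IDs before scoring.
scores = model.predict(np.array([['a', 'likes', 'c'],
                                 ['b', 'knows', 'a']]))
print(scores)  # one plausibility score per input triple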
batches_count: int
Number of batches to complete one epoch of the Platt scaling training.
Returns
-------
scores_pos: tf.Tensor
Tensor with positive scores.
scores_neg: tf.Tensor
Tensor with negative scores (generated by the corruptions).
dataset_handle: NumpyDatasetAdapter
Dataset handle (only used for clean-up).
"""
dataset_handle = NumpyDatasetAdapter()
dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)
dataset_handle.set_data(X_pos, "pos")
gen_fn = partial(dataset_handle.get_next_batch, batches_count=batches_count, dataset_type="pos")
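# Stream the positive triples from the adapter as a repeated, prefetched tf.data pipeline.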
dataset = tf.data.Dataset.from_generator(gen_fn,
output_types=tf.int32,
output_shapes=(None, 3))
dataset = dataset.repeat().prefetch(1)
dataset_iter = tf.data.make_one_shot_iterator(dataset)
x_pos_tf = dataset_iter.get_next()
e_s, e_p, e_o = self._lookup_embeddings(x_pos_tf)
scores_pos = self._fn(e_s, e_p, e_o)
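The ``scores_pos``/``scores_neg`` tensors produced here are the inputs to the Platt scaling step mentioned in the docstring. As a rough, self-contained illustration of what that calibration amounts to (a NumPy sketch, not the library's TensorFlow implementation), Platt scaling fits a single weight and bias so that a sigmoid of the rescaled score approximates the probability that a triple is true:
import numpy as np

def platt_scale(scores_pos, scores_neg, epochs=500, lr=0.01):
    # Toy Platt scaling: fit sigmoid(w * score + b) by gradient descent on log-loss.
    scores = np.concatenate([scores_pos, scores_neg])
    labels = np.concatenate([np.ones_like(scores_pos), np.zeros_like(scores_neg)])
    w, b = 1.0, 0.0
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(w * scores + b)))  # predicted probabilities
        grad = p - labels                            # d(log-loss)/d(logit)
        w -= lr * np.mean(grad * scores)
        b -= lr * np.mean(grad)
    return w, b

# Calibrate on toy scores, then turn a raw score into a probability.
w, b = platt_scale(np.array([2.1, 1.7, 3.0]), np.array([-1.2, 0.1, -0.5]))
prob = 1.0 / (1.0 + np.exp(-(w * 1.5 + b)))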
if type(X) is not np.ndarray:
X = np.array(X)
if not self.dealing_with_large_graphs:
if not from_idx:
X = to_idx(X, ent_to_idx=self.ent_to_idx, rel_to_idx=self.rel_to_idx)
x_tf = tf.Variable(X, dtype=tf.int32, trainable=False)
e_s, e_p, e_o = self._lookup_embeddings(x_tf)
scores = self._fn(e_s, e_p, e_o)
with tf.Session(config=self.tf_config) as sess:
sess.run(tf.global_variables_initializer())
return sess.run(scores)
else:
dataset_handle = NumpyDatasetAdapter()
dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)
dataset_handle.set_data(X, "test", mapped_status=from_idx)
self.eval_dataset_handle = dataset_handle
# build tf graph for predictions
self.rnd = check_random_state(self.seed)
tf.random.set_random_seed(self.seed)
# load the parameters
# build the eval graph
self._initialize_eval_graph()
with tf.Session(config=self.tf_config) as sess:
sess.run(tf.tables_initializer())
sess.run(tf.global_variables_initializer())
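Both the small-graph and large-graph branches above end in the same pattern: look up the subject, predicate and object embeddings and pass them to the model-specific scoring function ``self._fn``. What that function computes depends on the model class; as one concrete sketch (assuming a TransE-style model, purely for illustration), the score is the negative L1 distance between ``e_s + e_p`` and ``e_o``:
import numpy as np

def transe_score(e_s, e_p, e_o):
    # TransE-style scoring sketch: -||e_s + e_p - e_o||_1, one score per triple.
    # Higher (less negative) scores indicate more plausible triples.
    return -np.sum(np.abs(e_s + e_p - e_o), axis=1)

# Toy embeddings for a batch of two triples with embedding dimension 4.
e_s, e_p, e_o = (np.random.rand(2, 4) for _ in range(3))
print(transe_score(e_s, e_p, e_o))  # array of two scores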
- **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
- **'check_interval'**: int : Early stopping interval after burn-in (default: 10).
- **'stop_interval'**: int : Stop if the criterion performs worse over n consecutive checks (default: 3).
- **'corruption_entities'**: List of entities to be used for corruptions. If 'all',
it uses all entities (default: 'all')
- **'corrupt_side'**: Specifies which side to corrupt: 's', 'o', or 's+o' (default).
Example: ``early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}``
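For instance, a fuller, hypothetical call combining these options (``X['train']`` and ``X['valid']`` stand in for your own training and validation splits) could look like:

>>> model.fit(X['train'],
>>>           early_stopping=True,
>>>           early_stopping_params={'x_valid': X['valid'],
>>>                                  'criteria': 'mrr',
>>>                                  'burn_in': 100,
>>>                                  'check_interval': 10,
>>>                                  'stop_interval': 3,
>>>                                  'corrupt_side': 's+o'})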
"""
self.train_dataset_handle = None
# try-except block is mainly to handle clean up in case of exception or manual stop in jupyter notebook
try:
if isinstance(X, np.ndarray):
# Adapt the numpy data in the internal format - to generalize
self.train_dataset_handle = NumpyDatasetAdapter()
self.train_dataset_handle.set_data(X, "train")
elif isinstance(X, AmpligraphDatasetAdapter):
self.train_dataset_handle = X
else:
msg = 'Invalid type for input X. Expected ndarray/AmpligraphDataset object, got {}'.format(type(X))
logger.error(msg)
raise ValueError(msg)
# create internal IDs mappings
self.rel_to_idx, self.ent_to_idx = self.train_dataset_handle.generate_mappings()
prefetch_batches = 1
if len(self.ent_to_idx) > ENTITY_THRESHOLD:
self.dealing_with_large_graphs = True
prefetch_batches = 0
logger.warning('Your graph has a large number of distinct entities. '
>>> use_default_protocol=False)
>>> ranks
array([ 1, 582, 543, 6, 31])
>>> mrr_score(ranks)
0.24049691297347323
>>> hits_at_n_score(ranks, n=10)
0.4
"""
dataset_handle = None
# try-except block is mainly to handle clean up in case of exception or manual stop in jupyter notebook
try:
logger.debug('Evaluating the performance of the embedding model.')
if isinstance(X, np.ndarray):
X_test = filter_unseen_entities(X, model, verbose=verbose, strict=strict)
dataset_handle = NumpyDatasetAdapter()
dataset_handle.use_mappings(model.rel_to_idx, model.ent_to_idx)
dataset_handle.set_data(X_test, "test")
elif isinstance(X, AmpligraphDatasetAdapter):
dataset_handle = X
else:
msg = "X must be either a numpy array or an AmpligraphDatasetAdapter."
logger.error(msg)
raise ValueError(msg)
if filter_triples is not None:
if isinstance(filter_triples, np.ndarray):
logger.debug('Getting filtered triples.')
filter_triples = filter_unseen_entities(filter_triples, model, verbose=verbose, strict=strict)
dataset_handle.set_filter(filter_triples)
model.set_filter_for_eval()
elif isinstance(X, AmpligraphDatasetAdapter):