Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_top_n_numpy(self):
    """Table-driven check of ``filtering.filter_top_k`` on a NumPy array.

    Each case asks for the ``k`` largest entries. Per the expectations below,
    ``ascending=True`` changes only the order in which the selected entries
    (and their indices) are returned, not which entries are selected.
    """
    # NOTE(review): SOURCE defines ``test_top_n_numpy`` more than once; if the
    # copies share a class, only the last definition is collected by the runner.
    values = np.array([-10, -0.1, 0.0, 0, 3.4, 1.5])
    # Each case: (k, extra keyword arguments, expected values, expected indices).
    cases = (
        # ============== Without replace value, descending
        (1, {}, [3.4], [4]),
        (2, {}, [3.4, 1.5], [4, 5]),
        # ============== Without replace value, ascending
        (1, {"ascending": True}, [3.4], [4]),
        (2, {"ascending": True}, [1.5, 3.4], [5, 4]),
    )
    for k, kwargs, want_values, want_idxes in cases:
        got_values, got_idxes = filtering.filter_top_k(values, k, **kwargs)
        np.testing.assert_equal(got_values, np.array(want_values))
        np.testing.assert_equal(got_idxes, np.array(want_idxes))
    # FIXME: requesting every element (k == len(values)) is not handled yet;
    # add a case to the table once filter_top_k supports it.
def test_top_n_pytorch(self):
    """Check ``filtering.filter_top_k`` on a float32 torch tensor (no replace).

    Results are converted to NumPy before comparison. Expected values are
    built with ``dtype=np.float32`` because float32(3.4) widened to float64
    differs from the float64 literal 3.4, so a float64 expectation would
    never compare equal.
    """
    # NOTE(review): SOURCE defines ``test_top_n_pytorch`` more than once; if
    # the copies share a class, only the last definition is collected.
    data = torch.tensor(np.array([-10, -0.1, 0.0, 0, 3.4, 1.5]), dtype=torch.float)

    def to_numpy(tensor):
        # Detach from autograd and move to host memory before handing to NumPy.
        return tensor.detach().cpu().numpy()

    # ============== Without replace value, descending
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1)
    np.testing.assert_equal(to_numpy(modified_data), np.array([3.4], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([4]))
    # Top 2
    modified_data, idxes = filtering.filter_top_k(data, 2)
    np.testing.assert_equal(to_numpy(modified_data), np.array([3.4, 1.5], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([4, 5]))
    # ============== Without replace value, ascending
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1, ascending=True)
    np.testing.assert_equal(to_numpy(modified_data), np.array([3.4], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([4]))
    # Top 2: same two entries as the descending case, returned smallest first.
    modified_data, idxes = filtering.filter_top_k(data, 2, ascending=True)
    np.testing.assert_equal(to_numpy(modified_data), np.array([1.5, 3.4], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([5, 4]))
def test_top_n_pytorch(self):
    """Check ``filtering.filter_top_k`` on a float32 torch tensor.

    Covers descending (default) and ascending ordering of the selected
    values, and the ``replace`` mode where positions outside the top-k are
    expected to hold the replacement value. Every result is converted to
    NumPy and compared against ``float32`` expectations: comparing a raw
    float32 tensor with the float64 literal 3.4 always fails, because
    float32(3.4) widened to float64 is not exactly 3.4.
    """
    data = torch.tensor(np.array([-10, -0.1, 0.0, 0, 3.4, 1.5]), dtype=torch.float)

    def to_numpy(tensor):
        # Detach from autograd and move to host memory before handing to NumPy.
        return tensor.detach().cpu().numpy()

    # ============== Without replace value, descending
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1)
    np.testing.assert_equal(to_numpy(modified_data), np.array([3.4], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([4]))
    # Top 2
    modified_data, idxes = filtering.filter_top_k(data, 2)
    np.testing.assert_equal(to_numpy(modified_data), np.array([3.4, 1.5], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([4, 5]))
    # ============== Without replace value, ascending
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1, ascending=True)
    np.testing.assert_equal(to_numpy(modified_data), np.array([3.4], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([4]))
    # Top 2
    modified_data, idxes = filtering.filter_top_k(data, 2, ascending=True)
    np.testing.assert_equal(to_numpy(modified_data), np.array([1.5, 3.4], dtype=np.float32))
    np.testing.assert_equal(to_numpy(idxes), np.array([5, 4]))
    # ============== With replace value
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1, replace=-99)
    np.testing.assert_equal(
        to_numpy(modified_data),
        np.array([-99, -99, -99, -99, 3.4, -99], dtype=np.float32),
    )
    np.testing.assert_equal(to_numpy(idxes), np.array([4]))
    # Top 2
    modified_data, idxes = filtering.filter_top_k(data, 2, replace=-99)
    np.testing.assert_equal(
        to_numpy(modified_data),
        np.array([-99, -99, -99, -99, 3.4, 1.5], dtype=np.float32),
    )
    # Default ordering is descending, so the index of 3.4 comes before 1.5's.
    np.testing.assert_equal(to_numpy(idxes), np.array([4, 5]))
    # FIXME: k == len(data) and k > len(data) (e.g. 100) are not handled yet;
    # add cases here once filter_top_k supports them.
def test_top_n_numpy(self):
    """Check ``filtering.filter_top_k`` on a 1-D NumPy array.

    Covers descending (default) and ascending ordering of the selected
    values, and the ``replace`` mode, in which the returned array keeps the
    input length and positions outside the top-k are expected to hold the
    replacement value. Results are plain ndarrays, so they are compared
    directly (no tensor conversion).
    """
    data = np.array([-10, -0.1, 0.0, 0, 3.4, 1.5])
    # ============== Without replace value, descending
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1)
    np.testing.assert_equal(modified_data, np.array([3.4]))
    np.testing.assert_equal(idxes, np.array([4]))
    # Top 2
    modified_data, idxes = filtering.filter_top_k(data, 2)
    np.testing.assert_equal(modified_data, np.array([3.4, 1.5]))
    np.testing.assert_equal(idxes, np.array([4, 5]))
    # ============== Without replace value, ascending
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1, ascending=True)
    np.testing.assert_equal(modified_data, np.array([3.4]))
    np.testing.assert_equal(idxes, np.array([4]))
    # Top 2
    modified_data, idxes = filtering.filter_top_k(data, 2, ascending=True)
    np.testing.assert_equal(modified_data, np.array([1.5, 3.4]))
    np.testing.assert_equal(idxes, np.array([5, 4]))
    # ============== With replace value
    # Top 1
    modified_data, idxes = filtering.filter_top_k(data, 1, replace=-99)
    np.testing.assert_equal(modified_data, np.array([-99, -99, -99, -99, 3.4, -99]))
    np.testing.assert_equal(idxes, np.array([4]))
    # Top 2
    modified_data, idxes = filtering.filter_top_k(data, 2, replace=-99)
    np.testing.assert_equal(modified_data, np.array([-99, -99, -99, -99, 3.4, 1.5]))
    # Default ordering is descending, so the index of 3.4 comes before 1.5's.
    np.testing.assert_equal(idxes, np.array([4, 5]))
    # FIXME: k == len(data) and k > len(data) (e.g. 100) are not handled yet;
    # add cases here once filter_top_k supports them.
def filtering(self, logits, seed):
top_k = seed['top_k']
top_p = seed['top_p']
if top_k is not None and 0 < top_k < len(logits):
logits, idxes = filtering.filter_top_k(logits, top_k, replace=-float('Inf'))
if top_p is not None and 0 < top_p < 1:
logits, idxes = filtering.nucleus_sampling(logits, top_p)
# If top_p is not None, value will be sorted, so no need to select it again
if top_p is None:
if top_k is None:
idxes = np.arange(len(logits)).tolist()
else:
logits = logits.index_select(0, idxes)
if self.device == 'cuda':
idxes = idxes.cpu()
idxes = idxes.detach().numpy().tolist()
else:
logits = logits[:len(idxes)]
if self.device == 'cuda':
idxes = idxes.cpu()