Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_memoryview_float_notrans(A, B, a_rows, a_cols, out_cols):
    """Compare ``gemm`` (no transposition) with ``ndarray.dot`` for 'float32'.

    The incoming arrays and dimensions are handed to ``_reshape_for_gemm``,
    which yields the two operands and the output buffer. Each of the three
    must pass the ``assume`` guards (not ``None``, at least one element)
    before the product is computed and checked.
    """
    lhs, rhs, out = _reshape_for_gemm(A, B, a_rows, a_cols, out_cols, dtype='float32')
    arrays = (lhs, rhs, out)
    # All None checks run before any size check, so ``.size`` is only read
    # once every array has been guarded against None.
    for arr in arrays:
        assume(arr is not None)
    for arr in arrays:
        assume(arr.size >= 1)
    # gemm writes its result into ``out``; the reference product is computed
    # afterwards from the same operands.
    gemm(lhs, rhs, out=out)
    expected = lhs.dot(rhs)
    assert_allclose(expected, out, atol=1e-3, rtol=1e-3)
def test_memoryview_double_notrans(A, B, a_rows, a_cols, out_cols):
    """Compare ``gemm`` (no transposition) with ``ndarray.dot`` for 'float64'.

    ``_reshape_for_gemm`` turns the raw inputs into a left operand, a right
    operand and an output buffer. The example is only exercised when all
    three exist and are non-empty, as enforced through ``assume``.
    """
    left, right, result = _reshape_for_gemm(A, B, a_rows, a_cols, out_cols, 'float64')

    # Existence guards first, then the emptiness guards, in A, B, C order.
    assume(left is not None)
    assume(right is not None)
    assume(result is not None)
    assume(left.size >= 1)
    assume(right.size >= 1)
    assume(result.size >= 1)

    gemm(left, right, out=result)
    reference = left.dot(right)
    assert_allclose(reference, result, atol=1e-3, rtol=1e-3)
def blis_gemm(X, W, n=1000):
    """Call ``gemm(X, W, out=y)`` ``n`` times and print the summed outputs.

    Args:
        X: Left operand; ``X.shape[0]`` gives the number of output rows.
        W: Right operand; must be 2-D (its shape is unpacked into two
            values). It is passed to ``gemm`` without any transposition
            flag.
        n: Number of repetitions of the multiply.

    Returns:
        None. The accumulated sum is printed as ``Total: <value>``.

    The output buffer is allocated as ``(X.shape[0], W.shape[1])``, i.e. the
    shape of the untransposed product ``X . W``. Sizing it from
    ``W.shape[0]`` instead only fits when ``W`` happens to be square.
    NOTE(review): if callers actually store ``W`` as (outputs, inputs), the
    call itself would need a transposition option rather than a different
    buffer shape -- confirm against the callers.
    """
    # Unpacking (rather than indexing) keeps the check that W is exactly 2-D.
    _, n_cols = W.shape
    n_rows = X.shape[0]
    total = 0.0
    y = numpy.zeros((n_rows, n_cols), dtype="f")
    for _ in range(n):
        gemm(X, W, out=y)
        total += y.sum()
        # Reset the buffer between passes; presumably gemm accumulates into
        # ``out`` -- TODO confirm against the gemm documentation.
        y.fill(0.0)
    print("Total:", total)