# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Creation of the trainer
# NOTE(review): this chunk was pasted with its indentation stripped, so which
# lines belong inside which `with`/`for`/`try` cannot be recovered here; the
# statements are left exactly as found.
with logger.section("Create trainer"):
# Adam optimizer; the learning rate is taken from `args.lr`.
optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
# Training input pipeline: an initializable iterator whose `get_next()`
# yields a (data, target) pair.
train_iterator = train_dataset.make_initializable_iterator()
data, target = train_iterator.get_next()
train_loss = loss(model, data, target)
# Op returned by `minimize`; running it in a session performs one
# optimisation step on `train_loss`.
train_op = optimizer.minimize(train_loss)
# Test input pipeline. `data`/`target` are rebound here to the test
# tensors; `train_loss`/`train_op` above already captured the training ones.
test_iterator = test_dataset.make_initializable_iterator()
data, target = test_iterator.get_next()
test_loss = loss(model, data, target)
test_accuracy = accuracy(model, data, target)
# Register the values tracked by the logger. The exact meaning of
# `queue_limit` / `is_histogram` is defined by the logger library (not
# visible here) -- presumably buffering and histogram output; confirm.
logger.add_indicator("train_loss", queue_limit=10, is_print=True)
logger.add_indicator("test_loss", is_histogram=False, is_print=True)
logger.add_indicator("accuracy", is_histogram=False, is_print=True)
#
# Number of full batches per epoch (`//` drops any partial final batch).
# NOTE(review): `batches` is not used anywhere within this chunk.
batches = len(x_train) // args.batch_size
with tf.Session() as session:
# Hand the session to the experiment tracker before training starts.
EXPERIMENT.start_train(session)
# Loop through the monitored iterator
for epoch in logger.loop(range(0, args.epochs)):
# Delayed keyboard interrupt handling, so that a keyboard interrupt can
# be used to end the loop. The interrupt is captured and the loop
# finishes at the end of the current iteration; i.e. the loop won't
# stop in the middle of an epoch.
try:
with logger.delayed_keyboard_interrupt():
# NOTE(review): the body of this `with` block and the `except`/`finally`
# clause for the `try` above are missing from this chunk; the lines that
# follow belong to unrelated test code.
def compare(a):
    """Assert that ``B.transpose(a)`` matches NumPy's ``.T`` of ``a``.

    Returns nothing; the comparison itself is done by ``allclose``.

    NOTE(review): a later ``compare(a, b)`` in this file rebinds this name
    at module level.
    """
    transposed = B.transpose(a)
    reference = to_np(a).T
    allclose(transposed, reference)
def test_uprank():
    """Check ``uprank`` and the output shapes of ``OneKernel`` and ``OneMean``.

    NOTE(review): another ``def test_uprank()`` follows later in this file.
    The later definition rebinds the module-level name, so pytest only
    collects that one; consider renaming or removing one of the copies.
    """
    # `uprank` of 0 at ranks 0, 1 and 2 should all compare close to [[0]].
    allclose(uprank(0), [[0]])
    allclose(uprank(np.array([0])), [[0]])
    allclose(uprank(np.array([[0]])), [[0]])
    # Exact type comparison: presumably `Component('test')` returns a class
    # and `uprank` must preserve it exactly -- confirm.
    assert type(uprank(Component('test')(0))) == Component('test')

    # Kernel output shapes for scalar, vector and matrix inputs.
    k = OneKernel()
    assert B.shape(k(0, 0)) == (1, 1)
    assert B.shape(k(0, np.ones(5))) == (1, 5)
    assert B.shape(k(0, np.ones((5, 2)))) == (1, 5)
    assert B.shape(k(np.ones(5), 0)) == (5, 1)
    assert B.shape(k(np.ones(5), np.ones(5))) == (5, 5)
    assert B.shape(k(np.ones(5), np.ones((5, 2)))) == (5, 5)
    assert B.shape(k(np.ones((5, 2)), 0)) == (5, 1)
    assert B.shape(k(np.ones((5, 2)), np.ones(5))) == (5, 5)
    assert B.shape(k(np.ones((5, 2)), np.ones((5, 2)))) == (5, 5)

    # Rank-3 inputs must be rejected.
    with pytest.raises(ValueError):
        k(0, np.ones((5, 2, 1)))
    with pytest.raises(ValueError):
        k(np.ones((5, 2, 1)))

    # Mean output shapes. Previously `m` was constructed but never checked.
    m = OneMean()
    assert B.shape(m(0)) == (1, 1)
    assert B.shape(m(np.ones(5))) == (5, 1)
    assert B.shape(m(np.ones((5, 2)))) == (5, 1)
def test_uprank():
    """Exercise ``uprank`` and the output shapes of ``OneKernel`` / ``OneMean``."""
    # Scalar, vector and matrix zeros should all compare close to [[0]].
    for value in (0, np.array([0]), np.array([[0]])):
        allclose(uprank(value), [[0]])
    assert type(uprank(Component('test')(0))) == Component('test')

    kernel = OneKernel()
    # (positional arguments to the kernel, expected output shape)
    kernel_cases = [
        ((0, 0), (1, 1)),
        ((0, np.ones(5)), (1, 5)),
        ((0, np.ones((5, 2))), (1, 5)),
        ((np.ones(5), 0), (5, 1)),
        ((np.ones(5), np.ones(5)), (5, 5)),
        ((np.ones(5), np.ones((5, 2))), (5, 5)),
        ((np.ones((5, 2)), 0), (5, 1)),
        ((np.ones((5, 2)), np.ones(5)), (5, 5)),
        ((np.ones((5, 2)), np.ones((5, 2))), (5, 5)),
    ]
    for args, expected_shape in kernel_cases:
        assert B.shape(kernel(*args)) == expected_shape

    # Rank-3 inputs are rejected, both as the second and as the only argument.
    for bad_args in ((0, np.ones((5, 2, 1))), (np.ones((5, 2, 1)),)):
        with pytest.raises(ValueError):
            kernel(*bad_args)

    mean = OneMean()
    mean_cases = [
        (0, (1, 1)),
        (np.ones(5), (5, 1)),
        (np.ones((5, 2)), (5, 1)),
    ]
    for x, expected_shape in mean_cases:
        assert B.shape(mean(x)) == expected_shape
def test_dtype():
    """Check that ``B.dtype`` reports the element type of each matrix type.

    Each matrix type is built once from integer data and once from float
    data, and both results are checked.

    NOTE(review): the ``np.int64`` expectations assume NumPy's default
    integer is 64-bit. That does not hold on Windows with NumPy < 2, where
    ``np.array([1])`` is ``int32``.
    """
    # Test `Dense`.
    assert B.dtype(Dense(np.array([[1]]))) == np.int64
    assert B.dtype(Dense(np.array([[1.0]]))) == np.float64

    # Test `Diagonal`.
    diag_int = Diagonal(np.array([1]))
    diag_float = Diagonal(np.array([1.0]))
    assert B.dtype(diag_int) == np.int64
    assert B.dtype(diag_float) == np.float64

    # Test `LowRank`.
    lr_int = LowRank(left=np.array([[1]]),
                     right=np.array([[2]]),
                     middle=np.array([[3]]))
    lr_float = LowRank(left=np.array([[1.0]]),
                       right=np.array([[2.0]]),
                       middle=np.array([[3.0]]))
    assert B.dtype(lr_int) == np.int64
    assert B.dtype(lr_float) == np.float64

    # Test `Constant`: built from Python scalars, the expected dtype is the
    # builtin `int` / `float` rather than a NumPy scalar type.
    assert B.dtype(Constant(1, rows=1)) == int
    assert B.dtype(Constant(1.0, rows=1)) == float

    # Test `Woodbury`, composed of the `Diagonal` and `LowRank` matrices above.
    # (The Diagonal/LowRank/Constant checks were previously duplicated here
    # under this heading.)
    assert B.dtype(Woodbury(diag_int, lr_int)) == np.int64
    assert B.dtype(Woodbury(diag_float, lr_float)) == np.float64
def test_inverse_and_logdet():
    """Check ``B.inverse`` and ``B.logdet`` against dense NumPy references."""

    def check_inverse_and_logdet(mat):
        # The inverse must give the 3x3 identity from either side, and the
        # log-determinant must match NumPy's on the dense form.
        allclose(B.matmul(mat, B.inverse(mat)), np.eye(3))
        allclose(B.matmul(B.inverse(mat), mat), np.eye(3))
        allclose(B.logdet(mat), np.log(np.linalg.det(to_np(mat))))

    # `Dense`: a random symmetric positive semi-definite matrix M M^T.
    raw = np.random.randn(3, 3)
    dense = Dense(raw.dot(raw.T))
    check_inverse_and_logdet(dense)

    # `Diagonal`, including a rectangular (2 x 4) one whose inverse is 4 x 2.
    diag = Diagonal(np.array([1, 2, 3]))
    check_inverse_and_logdet(diag)
    rectangular = Diagonal(np.array([1, 2]), rows=2, cols=4)
    assert B.shape(B.inverse(rectangular)) == (4, 2)

    # `Woodbury`: diagonal plus low-rank. The checks are repeated on the
    # matrix and on successive inverses of it.
    left = np.random.randn(3, 2)
    root = np.random.randn(2, 2) + 1e-2 * np.eye(2)
    woodbury = diag + LowRank(left=left, middle=root.dot(root.T))
    for _ in range(4):
        check_inverse_and_logdet(woodbury)
        woodbury = B.inverse(woodbury)

    # TODO(review): a `LowRank` section was announced here but contains no
    # assertions; `LowRank` inverse/logdet are not exercised by this test.
def test_sample():
    """Check that ``B.sample`` draws from the requested covariance.

    For a dense covariance and a diagonal-plus-low-rank one, the empirical
    second moment ``S S^T / n`` of ``n`` samples must be within a mean
    absolute error of 5e-2 of the covariance. This assumes samples are
    stored column-wise (one column per draw) -- confirm.
    """
    root = np.random.randn(3, 3)
    dense = Dense(root.dot(root.T))
    factor = np.random.randn(2, 2)
    diagonal_part = Diagonal(B.diag(dense))
    low_rank_part = LowRank(left=np.random.randn(3, 2),
                            middle=factor.dot(factor.T))
    woodbury = diagonal_part + low_rank_part

    # Test `Dense` and `Woodbury`.
    count = 500000
    for covariance in (dense, woodbury):
        draws = B.sample(covariance, count)
        empirical = B.matmul(draws, draws, tr_b=True) / count
        error = np.mean(np.abs(to_np(empirical) - to_np(covariance)))
        assert error <= 5e-2
def compare(a, b):
    """Return whether ``B.matmul`` on ``a``, ``b`` matches the dense product.

    The product of the (possibly structured) operands is converted to NumPy
    and compared, with ``np.allclose``, against the product of their dense
    NumPy forms.
    """
    structured_product = to_np(B.matmul(a, b))
    dense_product = B.matmul(to_np(a), to_np(b))
    return np.allclose(structured_product, dense_product)
# Check `B.diag` for several matrix types against NumPy's `np.diag`.
# NOTE(review): this function appears to continue past the end of this chunk
# (it stops on a "Test `Woodbury`" heading); only the visible part is
# documented, and its indentation was stripped in this paste.
def test_diag():
# Test `Dense` (non-square 5 x 3 input).
a = np.random.randn(5, 3)
allclose(B.diag(Dense(a)), np.diag(a))
# Test `Diagonal`. With a second argument of 2 the diagonal is [1, 2]; with 4
# it is zero-padded to [1, 2, 3, 0].
allclose(B.diag(Diagonal(np.array([1, 2, 3]))), [1, 2, 3])
allclose(B.diag(Diagonal(np.array([1, 2, 3]), 2)), [1, 2])
allclose(B.diag(Diagonal(np.array([1, 2, 3]), 4)), [1, 2, 3, 0])
# Test `LowRank`, with square (a a^T, b b^T) and non-square (a b^T,
# 5 x 10) products; `a` is 5 x 3 and `b` is 10 x 3.
b = np.random.randn(10, 3)
allclose(B.diag(LowRank(left=a, right=a)), np.diag(a.dot(a.T)))
allclose(B.diag(LowRank(left=a, right=b)), np.diag(a.dot(b.T)))
allclose(B.diag(LowRank(left=b, right=b)), np.diag(b.dot(b.T)))
# Test `Constant`: a 3 x 5 matrix of ones has a diagonal of three ones.
allclose(B.diag(Constant(1, rows=3, cols=5)), np.ones(3))
# Test `Woodbury`.