Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""
import time, copy
if x_dev == None or y_dev == None:
from sklearn.cross_validation import train_test_split
x_train, x_dev, y_train, y_dev = train_test_split(x_train, y_train,
test_size=split_ratio, random_state=42)
if method == 'sgd':
train_fn = self.get_SGD_trainer()
elif method == 'adagrad':
train_fn = self.get_adagrad_trainer()
elif method == 'adadelta':
train_fn = self.get_adadelta_trainer()
elif method == 'rmsprop':
train_fn = self.get_rmsprop_trainer(with_step_adapt=True,
nesterov=False)
train_set_iterator = DatasetMiniBatchIterator(x_train, y_train)
dev_set_iterator = DatasetMiniBatchIterator(x_dev, y_dev)
train_scoref = self.score_classif(train_set_iterator)
dev_scoref = self.score_classif(dev_set_iterator)
best_dev_loss = numpy.inf
epoch = 0
# TODO early stopping (not just cross val, also stop training)
if plot:
verbose = True
self._costs = []
self._train_errors = []
self._dev_errors = []
self._updates = []
init_lr = INIT_LR
if method == 'rmsprop':
init_lr = 1.E-6 # TODO REMOVE HACK
import time, copy
if x_dev == None or y_dev == None:
from sklearn.cross_validation import train_test_split
x_train, x_dev, y_train, y_dev = train_test_split(x_train, y_train,
test_size=split_ratio, random_state=42)
if method == 'sgd':
train_fn = self.get_SGD_trainer()
elif method == 'adagrad':
train_fn = self.get_adagrad_trainer()
elif method == 'adadelta':
train_fn = self.get_adadelta_trainer()
elif method == 'rmsprop':
train_fn = self.get_rmsprop_trainer(with_step_adapt=True,
nesterov=False)
train_set_iterator = DatasetMiniBatchIterator(x_train, y_train)
dev_set_iterator = DatasetMiniBatchIterator(x_dev, y_dev)
train_scoref = self.score_classif(train_set_iterator)
dev_scoref = self.score_classif(dev_set_iterator)
best_dev_loss = numpy.inf
epoch = 0
# TODO early stopping (not just cross val, also stop training)
if plot:
verbose = True
self._costs = []
self._train_errors = []
self._dev_errors = []
self._updates = []
init_lr = INIT_LR
if method == 'rmsprop':
init_lr = 1.E-6 # TODO REMOVE HACK
n_seen = 0
def score(self, x, y):
    """Return the mean of the scores computed on the dataset (x, y).

    The inputs are wrapped in a ``DatasetMiniBatchIterator`` and handed to
    ``self.score_classif``, which returns a scoring callable. That callable
    is invoked once and its result is averaged with ``numpy.mean``. The
    original docstring describes the averaged quantities as error rates.
    """
    batches = DatasetMiniBatchIterator(x, y)
    # score_classif builds the scorer; calling it yields the values to average.
    values = self.score_classif(batches)()
    return numpy.mean(values)