import xgboost as xgb


def run_test(name, params_fun):
    """Runs a distributed GPU test."""
    # Always call this before using the distributed module.
    xgb.rabit.init()
    rank = xgb.rabit.get_rank()
    world = xgb.rabit.get_world_size()

    # Load the data; the file is automatically sharded in distributed mode.
    dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
    dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

    params, n_rounds = params_fun(rank)

    # Specify validation sets to watch performance.
    watchlist = [(dtest, 'eval'), (dtrain, 'train')]

    # Run training; all the features of the training API are available.
    # Currently, this script only supports calling train once, for fault-recovery purposes.
    bst = xgb.train(params, dtrain, n_rounds, watchlist, early_stopping_rounds=2)

    # Have each worker save its model.
    model_name = "test.model.%s.%d" % (name, rank)
dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')
# Specify parameters via a dict; the definitions are the same as in the C++ version.
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
# Specify validation sets to watch performance.
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20
# Run training; all the features of the training API are available.
# Currently, this script only supports calling train once, for fault-recovery purposes.
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
# Save the model; only ask process 0 (rank 0) to do so.
if xgb.rabit.get_rank() == 0:
    bst.save_model("test.model")
    xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker that all training has been successful.
# This is only needed in distributed training.
xgb.rabit.finalize()
eval_dmatrices = [ToDMatrix.get_xgb_dmatrix(ctx[t[0].key]) for t in op.evals]
evals = tuple((m, ev[1]) for m, ev in zip(eval_dmatrices, op.evals))
params = op.params
params['nthread'] = ctx.get_ncores() or -1
kwargs = dict() if op.kwargs is None else op.kwargs
if op.tracker is None:
    # Non-distributed: train locally and store the pickled booster plus its history.
    local_history = dict()
    bst = train(params, dtrain, evals=evals,
                evals_result=local_history, **kwargs)
    ctx[op.outputs[0].key] = {'booster': pickle.dumps(bst), 'history': local_history}
else:
    # Distributed: join the rabit ring described by the tracker before training.
    rabit_args = ctx[op.tracker.key]
    rabit.init(rabit_args)
    try:
        local_history = dict()
        bst = train(params, dtrain, evals=evals, evals_result=local_history,
                    **kwargs)
        ret = {'booster': pickle.dumps(bst), 'history': local_history}
        if rabit.get_rank() != 0:
            # Only rank 0 keeps the trained booster; other workers return an empty dict.
            ret = {}
        ctx[op.outputs[0].key] = ret
    finally:
        rabit.finalize()
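# For context, an assumption rather than something taken from this snippet:
# rabit_args is typically a list of byte-encoded DMLC_* settings describing the
# tracker, in the same format passed to rabit.init in the snippet further below.
example_rabit_args = [b'DMLC_NUM_WORKER=2',
                      b'DMLC_TRACKER_URI=10.0.0.1',   # illustrative address
                      b'DMLC_TRACKER_PORT=9091']      # illustrative port
# rabit.init(example_rabit_args) would join this worker to that tracker's ring.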
# Probe the tracker with a plain TCP connection, retrying until it succeeds or
# the maximum number of attempts is exhausted.
successful_connection = False
attempt = 0
while not successful_connection and attempt < self.max_connect_attempts:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.connect((self.master_host, self.port))
            successful_connection = True
            self.logger.debug("Successfully connected to RabitTracker.")
    except OSError:
        self.logger.info("Failed to connect to RabitTracker on attempt {}".format(attempt))
        attempt += 1
        self.logger.info("Sleeping for {} sec before retrying".format(self.connect_retry_timeout))
        time.sleep(self.connect_retry_timeout)

if not successful_connection:
    self.logger.error("Failed to connect to Rabit Tracker after %s attempts", self.max_connect_attempts)
    raise Exception("Failed to connect to Rabit Tracker")
else:
    self.logger.info("Connected to RabitTracker.")

rabit.init(['DMLC_NUM_WORKER={}'.format(self.n_workers).encode(),
            'DMLC_TRACKER_URI={}'.format(self.master_host).encode(),
            'DMLC_TRACKER_PORT={}'.format(self.port).encode()])
# We can check that this Rabit instance has successfully connected to the
# tracker by getting its rank (i.e. its position in the ring), which should
# be unique for each instance.
self.logger.debug("Rabit started - Rank {}".format(rabit.get_rank()))
self.logger.debug("Executing user code")
# We can now run user code. Since XGBoost runs in the same process space,
# it will use the same instance of Rabit that we have configured. It has
# a number of checks throughout the learning process that call Rabit APIs
# to see if it is running in distributed mode; if it is, it will do the
# synchronization automatically.
#
# Hence we can now execute any XGBoost-specific training code and it will
# be synchronized across the workers automatically.
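# A minimal sketch of that "user code" step (an assumption, not from the original
# source; 'train.data' is a hypothetical file, and xgb/rabit are assumed to be
# imported at module level):
assert rabit.get_world_size() == self.n_workers  # sanity check: every worker joined
dtrain = xgb.DMatrix('train.data')
bst = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=10)
# Each worker ends up with the same synchronized booster; rank 0 typically saves it.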
def callback(env):
    """internal function"""
    if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
        return
    i = env.iteration
    if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
        msg = '\t'.join([format_metric(x, show_stdv) for x in env.evaluation_result_list])
        rabit.tracker_print('[%d]\t%s\n' % (i + start_iteration, msg))
return callback
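# Usage sketch (assumptions: this inner function is returned by an enclosing
# factory, referred to here by the hypothetical name 'make_evaluation_printer',
# which closes over period, show_stdv, start_iteration and format_metric;
# params, dtrain and watchlist are as in the snippets above):
printer = make_evaluation_printer(period=5, show_stdv=True)  # hypothetical factory
bst = xgb.train(params, dtrain, num_boost_round=50,
                evals=watchlist, callbacks=[printer])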
def _get_num_workers(self):
    return xgb.rabit.get_world_size()
def synchronize(self, data):
    """Synchronize data with the cluster.

    This function allows every node to share state with every other node easily.
    This allows things like determining which nodes have data or not.

    :param data: data to send to the cluster
    :return: aggregated data from all the nodes in the cluster
    """
    results = []
    for i in range(rabit.get_world_size()):
        if self.rank == i:
            logging.debug("Broadcasting data from self ({}) to others".format(self.rank))
            rabit.broadcast(data, i)
            results.append(data)
        else:
            logging.debug("Receiving data from {}".format(i))
            message = rabit.broadcast(None, i)
            results.append(message)
    return results
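# Usage sketch (assumptions: 'helper' is an instance of the class above, rabit
# has already been initialised, and dtrain is this worker's DMatrix): every
# worker reports whether it holds any rows, and each one gets back the full
# list, indexed by rank, e.g. [True, False, True].
has_data = dtrain.num_row() > 0
data_per_rank = helper.synchronize(has_data)
workers_with_data = [rank for rank, flag in enumerate(data_per_rank) if flag]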