How to use the xgboost.rabit.get_rank function in xgboost

To help you get started, we’ve selected a few xgboost examples, based on popular ways it is used in public projects.
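In a distributed job, each worker joins the Rabit ring, asks for its rank (a zero-based integer identifying the worker) and the world size (the total number of workers), trains, and finalizes. The sketch below shows the typical call sequence as comments (assuming a pre-2.0 xgboost where the `xgb.rabit` module exists; newer releases expose the same ideas under `xgboost.collective`), plus a small helper capturing the rank-0-is-master convention used throughout the snippets on this page:

```python
# Typical call sequence (pre-2.0 xgboost; requires a running tracker):
#
#   import xgboost as xgb
#   xgb.rabit.init()                      # join the cluster
#   rank = xgb.rabit.get_rank()           # this worker's id: 0, 1, ...
#   world = xgb.rabit.get_world_size()    # total number of workers
#   ...                                   # train as usual
#   xgb.rabit.finalize()                  # leave the cluster

def is_master(rank):
    """Rank 0 is the master worker by convention: it prints to the
    tracker and keeps the final model (see the snippets below)."""
    return rank == 0
```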


github dmlc / xgboost / tests / distributed / View on Github
def run_test(name, params_fun):
    """Runs a distributed GPU test."""
    # Always call this before using distributed module
    rank = xgb.rabit.get_rank()
    world = xgb.rabit.get_world_size()

    # Load file, file will be automatically sharded in distributed mode.
    dtrain = xgb.DMatrix('../../demo/data/agaricus.txt.train')
    dtest = xgb.DMatrix('../../demo/data/agaricus.txt.test')

    params, n_rounds = params_fun(rank)

    # Specify validation sets to watch performance
    watchlist = [(dtest, 'eval'), (dtrain, 'train')]

    # Run training; all the features of the training API are available.
    # Currently, this script only supports calling train once, for fault-recovery purposes.
    bst = xgb.train(params, dtrain, n_rounds, watchlist, early_stopping_rounds=2)

    # Have each worker save its model
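The snippet is truncated at the save step. One simple way to let every worker save its model without filename collisions is to key the path by rank; `worker_model_path` below is a hypothetical helper for illustration, not part of xgboost:

```python
def worker_model_path(rank, prefix="model"):
    # Hypothetical naming scheme: one file per worker, keyed by rank,
    # so concurrent saves from different workers never collide.
    return "{}.rank_{}.bin".format(prefix, rank)

# e.g. bst.save_model(worker_model_path(xgb.rabit.get_rank()))
```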
github dmlc / xgboost / tests / distributed / View on Github
X = np.array(X)
y = [1, 0]

dtrain = xgb.DMatrix(X, label=y)

param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
watchlist = [(dtrain, 'train')]
num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist)

if xgb.rabit.get_rank() == 0:
    xgb.rabit.tracker_print("Finished training\n")

# Notify the tracker that all training completed successfully.
# This is only needed in distributed training.
xgb.rabit.finalize()
github mars-project / mars / mars / learn / contrib / xgboost / View on Github
            # non-distributed
            local_history = dict()
            kwargs = dict() if op.kwargs is None else op.kwargs
            bst = train(params, dtrain, evals=evals,
                        evals_result=local_history, **kwargs)
            ctx[op.outputs[0].key] = {'booster': pickle.dumps(bst),
                                      'history': local_history}

            # distributed
            rabit_args = ctx[op.tracker.key]
            rabit.init(rabit_args)  # join the Rabit ring before training
            try:
                local_history = dict()
                bst = train(params, dtrain, evals=evals,
                            evals_result=local_history, **kwargs)
                ret = {'booster': pickle.dumps(bst), 'history': local_history}
                if rabit.get_rank() != 0:
                    ret = {}
                ctx[op.outputs[0].key] = ret
            finally:
                rabit.finalize()
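The distributed branch shows a common pattern: every worker trains (Rabit keeps the boosters synchronized), but only rank 0 stores the result, so it is not duplicated once per worker. Stripped of the Mars context, the pattern reduces to the illustrative helper below (`collect_result` is not part of xgboost or Mars):

```python
def collect_result(rank, booster_blob, history):
    # Every worker computes a result, but only rank 0 keeps it;
    # the other workers hand back an empty dict.
    if rank != 0:
        return {}
    return {'booster': booster_blob, 'history': history}
```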
github aws / sagemaker-xgboost-container / src / sagemaker_xgboost_container / View on Github

        if not successful_connection:
            self.logger.error("Failed to connect to Rabit Tracker after %s attempts",
                              self.max_connect_attempts)
            raise Exception("Failed to connect to Rabit Tracker")
        self.logger.debug("Connected to RabitTracker.")


        # We can check that the Rabit instance has successfully connected to the
        # tracker by getting this worker's rank (i.e. its position in the ring),
        # which is unique for each instance.
        self.logger.debug("Rabit started - Rank {}".format(rabit.get_rank()))
        self.logger.debug("Executing user code")

        # We can now run user code. Since XGBoost runs in the same process space,
        # it will use the same Rabit instance that we have configured. Throughout
        # the learning process it checks, via Rabit APIs, whether it is running
        # in distributed mode, and if so performs the synchronization
        # automatically. Hence any XGBoost-specific training code executed here
        # is distributed automatically.
        return RabitHelper(self.is_master_host, self.current_host, self.port)
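Because each worker knows its rank and the world size, user code can also partition work deterministically. The toy round-robin split below illustrates the idea; it is not how XGBoost shards a `DMatrix` internally, just a minimal sketch of rank-based partitioning:

```python
def shard(items, rank, world_size):
    # Worker `rank` takes every world_size-th item, starting at its rank.
    # Together the workers cover all items exactly once.
    return items[rank::world_size]
```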
github aws / sagemaker-xgboost-container / src / sagemaker_xgboost_container / View on Github
def __init__(self, is_master, current_host, master_port):
        """This is returned by the Rabit context manager for useful cluster information and data synchronization.

        :param is_master: whether this host is the Rabit master (rank 0)
        :param current_host: name of the current host
        :param master_port: port on which the Rabit tracker listens
        """
        self.is_master = is_master
        self.rank = rabit.get_rank()
        self.current_host = current_host
        self.master_port = master_port
github awslabs / sagemaker-debugger / smdebug / xgboost / View on Github
def _get_worker_name(self):
        return "worker_{}".format(xgb.rabit.get_rank())