How to use the mpi4py.MPI.COMM_WORLD.bcast function in mpi4py

To help you get started, we’ve selected a few mpi4py examples that show how MPI.COMM_WORLD.bcast is commonly used in public projects.
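Before the project examples, here is the call in its simplest form. bcast sends one picklable Python object from a designated root rank to every rank in the communicator; every rank must make the call, with the root passing the object and the others typically passing None. A minimal sketch (run with, e.g., mpirun -n 4 python demo.py):

from mpi4py import MPI

comm = MPI.COMM_WORLD

# Only the root rank holds the real data; the other ranks pass None.
if comm.Get_rank() == 0:
    data = {"alpha": 0.5, "paths": ["a.nc", "b.nc"]}
else:
    data = None

# Collective call: every rank returns the root's object.
data = comm.bcast(data, root=0)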


github firedrakeproject / firedrake / tests / output / test_hdf5file_checkpoint.py
def test_checkpoint_fails_for_non_function(dumpfile):
    dumpfile = MPI.COMM_WORLD.bcast(dumpfile, root=0)
    with HDF5File(dumpfile, "w", comm=MPI.COMM_WORLD) as h5:
        with pytest.raises(ValueError):
            h5.write(np.arange(10), "/solution")
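The broadcast here guarantees that every rank sees the same temporary filename before the file is opened collectively. The same idea outside pytest, as a minimal sketch (the suffix is arbitrary):

import tempfile
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Only rank 0 generates the (random) path; the broadcast makes it common.
path = tempfile.mktemp(suffix=".h5") if comm.rank == 0 else None
path = comm.bcast(path, root=0)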
github olcf / pcircle / pcp.py
def main():

    global ARGS, logger, circle
    signal.signal(signal.SIGINT, sig_handler)
    parse_flags = True

    if MPI.COMM_WORLD.rank == 0:
        try:
            ARGS = parse_args()
        except SystemExit:  # argparse raises SystemExit on bad arguments
            parse_flags = False

    parse_flags = MPI.COMM_WORLD.bcast(parse_flags)

    if parse_flags:
        ARGS = MPI.COMM_WORLD.bcast(ARGS)
    else:
        sys.exit(0)

    circle = Circle(reduce_interval=ARGS.reduce_interval)
    circle.setLevel(logging.ERROR)
    logger = utils.logging_init(logger, ARGS.loglevel)

    pcp = None
    totalsize = None

    if ARGS.rid:
        pcp, totalsize = main_resume(ARGS.rid[0])
    else:
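Two details are worth noting in this example. First, bcast's root argument defaults to 0, which is why these calls omit it. Second, the boolean flag is broadcast before the parsed arguments, so that when parsing fails on rank 0 every rank learns about it and exits together rather than blocking forever in a bcast that rank 0 never makes. The pattern in isolation, as a sketch (parse_args stands in for any argparse-style parser):

import sys
from mpi4py import MPI

comm = MPI.COMM_WORLD

ok = True
args = None
if comm.rank == 0:
    try:
        args = parse_args()  # hypothetical argparse wrapper; raises SystemExit on bad input
    except SystemExit:
        ok = False

ok = comm.bcast(ok)      # root defaults to 0
if not ok:
    sys.exit(0)          # all ranks exit together
args = comm.bcast(args)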
github jaredwo / topowx / twx / infill / mpi_infill_prcp_normals.py
def proc_write(params,nwrkers):

    status = MPI.Status()
    stn_da = StationDataDb(params[P_PATH_DB],(params[P_START_YMD],params[P_END_YMD]))
    days = stn_da.days
    nwrkrs_done = 0
    
    bcast_msg = None
    bcast_msg = MPI.COMM_WORLD.bcast(bcast_msg, root=RANK_COORD)
    stnids_prcp = bcast_msg
    print "Writer: Received broadcast msg"
    
    
    if params[P_NCDF_MODE] == 'r+':
        
        ds_prcp = Dataset("".join([params[P_PATH_OUT],DS_NAME]),'r+')
        ttl_infills = stnids_prcp.size
        stnids_prcp = np.array(ds_prcp.variables['stn_id'][:], dtype="<S16")  # fixed-width byte-string dtype; exact width assumed
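In this and the other topowx excerpts, the bcast_msg = None assignment is just the placeholder for non-coordinator ranks: bcast ignores its first argument everywhere except the root, so only the rebound return value matters. The two lines collapse to one, as a sketch (RANK_COORD is the project's coordinator-rank constant, assumed here to be 0):

from mpi4py import MPI

RANK_COORD = 0  # assumed value for this sketch

# Non-root ranks can simply pass None and use the return value.
bcast_msg = MPI.COMM_WORLD.bcast(None, root=RANK_COORD)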
github jaredwo / topowx / twx / infill / mpi_infill_optim_prcp.py
def proc_work(params,rank):
    
    source_r()
    
    status = MPI.Status()
    stn_da = StationDataDb(params[P_PATH_DB],(params[P_START_YMD],params[P_END_YMD]))
    days = stn_da.days
    mth_masks = build_mth_masks(days)
    mthbuf_masks = build_mth_masks(days,MTH_BUFFER)
    yrmth_masks = build_yr_mth_masks(days)
    
    ds_prcp = Dataset(params[P_PATH_NORMS])
    
    bcast_msg = None
    bcast_msg = MPI.COMM_WORLD.bcast(bcast_msg, root=RANK_COORD)
    stn_ids,xval_masks_prcp = bcast_msg
    print "".join(["Worker ",str(rank),": Received broadcast msg"])
    
    def transfunct_bin(x,params=None):
        x = np.copy(x)
        x[x > 0] = 1
        return x,params
    
    def btrans_square(x,params=None):
        
        x = np.copy(x)
        x[x < 0] = 0
        return np.square(x)
    
    ################################################
    mean_center_prcp = False
github openai / large-scale-curiosity / utils.py
def bcast_tf_vars_from_root(sess, vars):
    """
    Send the root node's parameters to every worker.

    Arguments:
      sess: the TensorFlow session.
      vars: all parameter variables including optimizer's
    """
    rank = MPI.COMM_WORLD.Get_rank()
    for var in vars:
        if rank == 0:
            MPI.COMM_WORLD.bcast(sess.run(var))
        else:
            sess.run(tf.assign(var, MPI.COMM_WORLD.bcast(None)))
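One bcast per variable works, but each call is a separate collective operation. A common variant, shown here as a sketch rather than the project's code (the function name is hypothetical), packs all values into a single list and broadcasts once:

import tensorflow as tf
from mpi4py import MPI

def bcast_tf_vars_from_root_packed(sess, tf_vars):
    """Variant that broadcasts all variable values in one collective call."""
    rank = MPI.COMM_WORLD.Get_rank()
    values = [sess.run(v) for v in tf_vars] if rank == 0 else None
    values = MPI.COMM_WORLD.bcast(values, root=0)
    if rank != 0:
        for var, val in zip(tf_vars, values):
            sess.run(tf.assign(var, val))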
github jaredwo / topowx / twx / interp / mpi_xval_po.py
def proc_work(params,rank):
    
    status = MPI.Status()
    stn_da = StationSerialDataDb(params[P_PATH_DB], params[P_VARNAME])
    
    mod = ip.modeler_clib_po()
    po_interper = ip.interp_po(mod)
    
    bcast_msg = None
    bcast_msg = MPI.COMM_WORLD.bcast(bcast_msg, root=RANK_COORD)    
    print "".join(["Worker ",str(rank),": Received broadcast msg"])
    
    while 1:
    
        stn_id,min_ngh = MPI.COMM_WORLD.recv(source=RANK_COORD,tag=MPI.ANY_TAG,status=status)
        
        if status.tag == TAG_STOPWORK:
            MPI.COMM_WORLD.send([None]*3, dest=RANK_WRITE, tag=TAG_STOPWORK)
            print "".join(["Worker ",str(rank),": Finished"]) 
            return 0
        else:
            
            try:
                
                stn_slct = station_select(stn_da.stns,min_ngh,min_ngh+10)
github jaredwo / topowx / twx / infill / mpi_infill_optim_tair_normals.py
def proc_work(params,rank):
    
    status = MPI.Status()
    stn_da = StationDataDb(params[P_PATH_DB],(params[P_START_YMD],params[P_END_YMD]))
    
    source_r(params[P_PATH_R_FUNCS])

    ds_nnr = NNRNghData(params[P_PATH_NNR], (params[P_START_YMD],params[P_END_YMD]))
        
    bcast_msg = None
    bcast_msg = MPI.COMM_WORLD.bcast(bcast_msg, root=RANK_COORD)
    
    stn_ids,xval_masks_tmin,xval_masks_tmax,mask_por_tmin,mask_por_tmax = bcast_msg
    xval_masks = {'tmin':xval_masks_tmin,'tmax':xval_masks_tmax}
    stn_masks = {'tmin':mask_por_tmin,'tmax':mask_por_tmax}
    
    aclib = clib_wxTopo('/home/jared.oyler/ecl_juno_workspace/wxtopo/wxTopo_C/Release/libwxTopo_C')
    
    print "".join(["Worker ",str(rank),": Received broadcast msg"])
        
    while 1:
    
        stn_id,nnghs,tair_var = MPI.COMM_WORLD.recv(source=RANK_COORD,tag=MPI.ANY_TAG,status=status)
        
        if status.tag == TAG_STOPWORK:
            MPI.COMM_WORLD.send([None]*6, dest=RANK_WRITE, tag=TAG_STOPWORK)
            print "".join(["Worker ",str(rank),": Finished"])
github jaredwo / topowx / twx / interp / mpi_interp_tair.py
def proc_write(params,nwrkers):

    status = MPI.Status()
    tile_num_msg = np.zeros(1,dtype=np.int32)
    nwrkrs_done = 0
    
    bcast_msg = None
    bcast_msg = MPI.COMM_WORLD.bcast(bcast_msg, root=RANK_COORD)
    tile_grid_info = bcast_msg  
    tile_ids = tile_grid_info.tile_ids
    nchks = tile_grid_info.nchks
    chks_per_tile = tile_grid_info.chks_per_tile
    
    tile_status = {}
    for key in tile_ids.keys():
        tile_status[key] = 0
    
    tile_queues = {}
    for key in tile_ids.keys():
        tile_queues[key] = deque()
    
    stat_chk = StatusCheck(nchks,1)
    
    while 1:
github flowersteam / curious / baselines / her / rollout.py
                        if self.p.sum() > 1:
                            self.p[np.argmax(self.p)] -= self.p.sum() - 1
                        elif self.p.sum() < 1:
                            self.p[-1] = 1 - self.p[:-1].sum()


                    elif self.structure == 'task_experts':
                        self.p = np.zeros([self.nb_tasks])
                        self.p[self.unique_task] = 1


            # broadcast the selection probability to all cpus and the competence
            if not self.eval:
                self.p = MPI.COMM_WORLD.bcast(self.p, root=0)
                self.CP = MPI.COMM_WORLD.bcast(self.CP, root=0)

        return convert_episode_to_batch_major(episode), self.CP, self.n_episodes
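bcast pickles whatever object it is given, which is fine for a small array like self.p above. For large numpy arrays, the capitalized, buffer-based Bcast avoids the pickling overhead; a minimal sketch:

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Bcast fills a pre-allocated buffer in place, so every rank must
# allocate the array with the same shape and dtype beforehand.
p = np.empty(5, dtype=np.float64)
if comm.rank == 0:
    p[:] = [0.3, 0.2, 0.2, 0.2, 0.1]
comm.Bcast(p, root=0)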
github jaredwo / topowx / twx / infill / mpi_infill_optim_tair_normals.py
        tmin_stn = tmin[:,x]
        tmax_stn = tmax[:,x]
        
        tmin_stn[xval_mask_tmin] = np.nan
        tmax_stn[xval_mask_tmax] = np.nan
        
        tmin[:,x] = tmin_stn
        tmax[:,x] = tmax_stn
    
    #Load the period-of-record datafile
    por = load_por_csv(params[P_PATH_POR])
    mask_por_tmin,mask_por_tmax = build_valid_por_masks(por,params[P_MIN_POR],params[P_STN_LOC_BNDS])[0:2]
    
    #Send stn ids and masks to all processes
    MPI.COMM_WORLD.bcast((fnl_stn_ids,xval_masks_tmin,xval_masks_tmax,mask_por_tmin,mask_por_tmax), root=RANK_COORD)
    
    stn_idxs = {}
    for x in np.arange(fnl_stn_ids.size):
        stn_idxs[fnl_stn_ids[x]] = x
    
    print "Coord: Done initialization. Starting to send work."
    
    cnt = 0
    nrec = 0
    
    for stn_id in fnl_stn_ids:
        
        for min_ngh in params[P_NGH_RNG]:
            
            for tair_var in ['tmin','tmax']:
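This last excerpt shows the sending side of the coordinator/worker pattern from the earlier topowx examples: the coordinator passes the real tuple while every worker passes None, and the two calls pair up as one collective. Stripped to its essentials, as a sketch (RANK_COORD and the payload contents are assumptions):

import numpy as np
from mpi4py import MPI

RANK_COORD = 0  # assumed coordinator rank

comm = MPI.COMM_WORLD

# The coordinator passes the real payload; workers pass None.
if comm.rank == RANK_COORD:
    payload = (np.array(["stn1", "stn2"]), {"tmin": None})  # illustrative stand-ins
else:
    payload = None

# One collective call: the coordinator's tuple comes back on every rank.
stn_ids, masks = comm.bcast(payload, root=RANK_COORD)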