Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
],
name="matmul", assumptions="n,m,ell >= 1")
# Give the array arguments concrete dtypes so loopy can generate code.
knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
# Tile i/j onto the OpenCL grid: outer pieces -> work-group axes (g.*),
# inner pieces -> work-item axes (l.*); k is tiled but stays sequential.
knl = lp.split_iname(knl, "i", bsize, outer_tag="g.0", inner_tag="l.1")
knl = lp.split_iname(knl, "j", bsize, outer_tag="g.1", inner_tag="l.0")
knl = lp.split_iname(knl, "k", bsize)
# Prefetch a- and b-tiles into local memory (local tags auto-assigned).
knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"], default_tag="l.auto")
knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"], default_tag="l.auto")
# Concrete problem sizes used to evaluate the symbolic count polynomials.
n = 512
m = 256
ell = 128
params = {'n': n, 'm': m, 'ell': ell}
# Derived launch geometry (bsize and SGS come from earlier in the file).
group_size = bsize*bsize
n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize)
subgroups_per_group = div_ceil(group_size, SGS)
n_subgroups = n_workgroups*subgroups_per_group
# Expected synchronization: one kernel launch, and 2*(m/bsize) local
# barriers -- presumably one per prefetch per k-tile (TODO confirm).
sync_map = lp.get_synchronization_map(knl)
assert len(sync_map) == 2
assert sync_map["kernel_launch"].eval_with_dict(params) == 1
assert sync_map["barrier_local"].eval_with_dict(params) == 2*m/bsize
# Count arithmetic ops at subgroup granularity, including work that is
# redundantly repeated across work-items.
op_map = lp.get_op_map(knl, subgroup_size=SGS, count_redundant_work=True)
f32mul = op_map[
lp.Op(np.float32, 'mul', CG.SUBGROUP)
].eval_with_dict(params)
f32add = op_map[
lp.Op(np.float32, 'add', CG.SUBGROUP)
].eval_with_dict(params)
i32ops = op_map[
lp.Op(np.int32, 'add', CG.SUBGROUP)
],
name="matmul", assumptions="n,m,ell >= 1")
# Give the array arguments concrete dtypes so loopy can generate code.
knl = lp.add_and_infer_dtypes(knl, dict(a=np.float32, b=np.float32))
# Tile i/j onto the OpenCL grid (g.* = work-group axes, l.* = work-item
# axes); k is tiled but stays sequential.
knl = lp.split_iname(knl, "i", bsize, outer_tag="g.0", inner_tag="l.1")
knl = lp.split_iname(knl, "j", bsize, outer_tag="g.1", inner_tag="l.0")
knl = lp.split_iname(knl, "k", bsize)
# Prefetching is deliberately left disabled in this variant, so the
# memory-access counts below reflect direct global-memory traffic.
# knl = lp.add_prefetch(knl, "a", ["k_inner", "i_inner"], default_tag="l.auto")
# knl = lp.add_prefetch(knl, "b", ["j_inner", "k_inner"], default_tag="l.auto")
# Concrete problem sizes used to evaluate the symbolic count polynomials.
n = 512
m = 256
ell = 128
params = {'n': n, 'm': m, 'ell': ell}
# Derived launch geometry (bsize and SGS come from earlier in the file).
group_size = bsize*bsize
n_workgroups = div_ceil(n, bsize)*div_ceil(ell, bsize)
subgroups_per_group = div_ceil(group_size, SGS)
n_subgroups = n_workgroups*subgroups_per_group
# Count global-memory accesses, including work redundantly repeated
# across work-items.
mem_access_map = lp.get_mem_access_map(knl, count_redundant_work=True,
subgroup_size=SGS)
# Loads of 'b': unit stride along local axis 0, stride bsize along group
# axis 1, counted per work-item.
f32s1lb = mem_access_map[lp.MemAccess('global', np.float32,
lid_strides={0: 1},
gid_strides={1: bsize},
direction='load', variable='b',
variable_tag='mmbload',
count_granularity=CG.WORKITEM)
].eval_with_dict(params)
f32s1la = mem_access_map[lp.MemAccess('global', np.float32,
lid_strides={1: Variable('m')},
gid_strides={0: Variable('m')*bsize},
direction='load',
def get_dev_group_size(device):
    """Pick a work-group size for *device*.

    Takes the smaller of the device's work-group limit and the
    local-memory-derived bound, then rounds it down to a power of two.
    ``out_type_size`` is read from the enclosing scope.
    """
    from pytools import div_ceil
    from pyopencl.tools import bitlog2

    # Dirty fix for the RV770 boards: cap their reported limit at 64.
    wg_limit = 64 if "RV770" in device.name else device.max_work_group_size

    # Bound imposed by local memory.
    lmem_bound = div_ceil(wg_limit, out_type_size)

    # Round the tighter of the two bounds down to a power of 2.
    return 2 ** bitlog2(min(wg_limit, lmem_bound))
def get_dev_group_size(device):
    """Return a power-of-two work-group size suitable for *device*.

    Uses ``out_type_size`` from the enclosing scope to derive the
    local-memory bound.
    """
    from pyopencl.tools import bitlog2
    from pytools import div_ceil

    max_wg = device.max_work_group_size
    if "RV770" in device.name:
        # dirty fix for the RV770 boards
        max_wg = 64

    # compute lmem limit
    lmem_wg = div_ceil(max_wg, out_type_size)
    chosen = lmem_wg if lmem_wg < max_wg else max_wg

    # round down to power of 2 (left-shift is equivalent to 2**k here)
    return 1 << bitlog2(chosen)
def get_dev_group_size(device):
    """Choose min(device limit, lmem limit) rounded down to a power of 2.

    ``out_type_size`` is a closure variable from the enclosing scope.
    """
    from pytools import div_ceil
    from pyopencl.tools import bitlog2

    limit = device.max_work_group_size
    # dirty fix for the RV770 boards
    if "RV770" in device.name:
        limit = 64

    # compute lmem limit and take the tighter of the two bounds
    candidate = min(limit, div_ceil(limit, out_type_size))

    # round down to power of 2
    return 2 ** bitlog2(candidate)
def __call__(self, queue, tree, wait_for=None):
"""
:arg queue: a :class:`pyopencl.CommandQueue`
:arg tree: a :class:`boxtree.Tree`.
:arg wait_for: may either be *None* or a list of :class:`pyopencl.Event`
instances for whose completion this command waits before starting
execution.
:returns: a tuple *(pl, event)*, where *pl* is an instance of
:class:`PeerListLookup`, and *event* is a :class:`pyopencl.Event`
for dependency management.
"""
from pytools import div_ceil
# Round up level count--this gets included in the kernel as
# a stack bound. Rounding avoids too many kernel versions.
max_levels = div_ceil(tree.nlevels, 10) * 10
peer_list_finder_kernel = self.get_peer_list_finder_kernel(
tree.dimensions, tree.coord_dtype, tree.box_id_dtype, max_levels)
pl_plog = ProcessLogger(logger, "find peer lists")
result, evt = peer_list_finder_kernel(
queue, tree.nboxes,
tree.box_centers.data, tree.root_extent,
tree.box_levels.data, tree.aligned_nboxes,
tree.box_child_ids.data, tree.box_flags.data,
wait_for=wait_for)
pl_plog.done()
return PeerListLookup(