Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def generate_kernel_wrapper(self, func, argtypes):
module = func.module
arginfo = self.get_arg_packer(argtypes)
def sub_gen_with_global(lty):
    """Retarget a pointer type at ``SPIR_GLOBAL_ADDRSPACE``.

    Returns a 2-tuple. For an ``llvmir.PointerType`` the first item is a
    pointer to the same pointee placed in ``SPIR_GLOBAL_ADDRSPACE`` and the
    second item is the address space the input type carried. Any other
    type is handed back untouched, paired with ``None``.
    """
    # Non-pointer types need no substitution.
    if not isinstance(lty, llvmir.PointerType):
        return lty, None
    global_ptr = lty.pointee.as_pointer(SPIR_GLOBAL_ADDRSPACE)
    previous_addrspace = lty.addrspace
    return global_ptr, previous_addrspace
if len(arginfo.argument_types) > 0:
llargtys, changed = zip(*map(sub_gen_with_global,
arginfo.argument_types))
else:
llargtys = changed = ()
wrapperfnty = lc.Type.function(lc.Type.void(), llargtys)
wrapper_module = self.create_module("hsa.kernel.wrapper")
wrappername = 'hsaPy_{name}'.format(name=func.name)
argtys = list(arginfo.argument_types)
fnty = lc.Type.function(lc.Type.int(),
[self.call_conv.get_return_type(
types.pyobject)] + argtys)
func = wrapper_module.add_function(fnty, name=func.name)
func.calling_convention = CC_SPIR_FUNC
wrapper = wrapper_module.add_function(wrapperfnty, name=wrappername)
builder = lc.Builder(wrapper.append_basic_block(''))
builder, func, types.void, argtypes, callargs)
if debug:
# Check error status
with cgutils.if_likely(builder, status.is_ok):
builder.ret_void()
with builder.if_then(builder.not_(status.is_python_exc)):
# User exception raised
old = Constant.null(gv_exc.type.pointee)
# Use atomic cmpxchg to prevent rewriting the error status
# Only the first error is recorded
casfnty = lc.Type.function(old.type, [gv_exc.type, old.type,
old.type])
casfn = wrapper_module.add_function(casfnty,
name="___numba_cas_hack")
xchg = builder.call(casfn, [gv_exc, old, status.code])
changed = builder.icmp(ICMP_EQ, xchg, old)
# If the xchange is successful, save the thread ID.
sreg = nvvmutils.SRegBuilder(builder)
with builder.if_then(changed):
for dim, ptr, in zip("xyz", gv_tid):
val = sreg.tid(dim)
builder.store(val, ptr)
for dim, ptr, in zip("xyz", gv_ctaid):
val = sreg.ctaid(dim)
def call_sreg(builder, name):
    """Emit a call to the function that ``SREG_MAPPING`` associates with *name*.

    The callee is declared (or fetched, if already present) in the module
    owned by *builder* as a zero-argument function returning
    ``lc.Type.int()``. The value produced by ``builder.call`` is returned.
    """
    llmod = builder.module
    noarg_int_fnty = lc.Type.function(lc.Type.int(), ())
    callee = llmod.get_or_insert_function(noarg_int_fnty,
                                          name=SREG_MAPPING[name])
    return builder.call(callee, ())
# note that the result value pointer as first argument is the convention
# used by numba.
# First, prepare the return value
out = context.make_complex(builder, ty)
ptrargs = [cgutils.alloca_once_value(builder, arg)
for arg in args]
call_args = [out._getpointer()] + ptrargs
# get_value_as_argument for struct types like complex allocate stack space
# and initialize with the value, the return value is the pointer to that
# allocated space (ie: pointer to a copy of the value in the stack).
# get_argument_type returns a pointer to the struct type in consonance.
call_argtys = [ty] + list(sig.args)
call_argltys = [context.get_value_type(ty).as_pointer()
for ty in call_argtys]
fnty = lc.Type.function(lc.Type.void(), call_argltys)
# Note: the function isn't pure here (it writes to its pointer args)
fn = mod.get_or_insert_function(fnty, name=func_name)
builder.call(fn, call_args)
retval = builder.load(call_args[0])
else:
argtypes = [context.get_argument_type(aty) for aty in sig.args]
restype = context.get_argument_type(sig.return_type)
fnty = lc.Type.function(restype, argtypes)
fn = cgutils.insert_pure_function(mod, fnty, name=func_name)
retval = context.call_external_function(builder, fn, sig.args, args)
return retval
def generate_kernel_wrapper(self, library, fname, argtypes, debug):
"""
Generate the kernel wrapper in the given ``library``.
The function being wrapped have the name ``fname`` and argument types
``argtypes``. The wrapper function is returned.
"""
arginfo = self.get_arg_packer(argtypes)
argtys = list(arginfo.argument_types)
wrapfnty = Type.function(Type.void(), argtys)
wrapper_module = self.create_module("cuda.kernel.wrapper")
fnty = Type.function(Type.int(),
[self.call_conv.get_return_type(types.pyobject)] + argtys)
func = wrapper_module.add_function(fnty, name=fname)
prefixed = itanium_mangler.prepend_namespace(func.name, ns='cudapy')
wrapfn = wrapper_module.add_function(wrapfnty, name=prefixed)
builder = Builder(wrapfn.append_basic_block(''))
# Define error handling variables
def define_error_gv(postfix):
    """Add a null-initialized ``Type.int()`` global to ``wrapper_module``.

    The global is named after the wrapper function (``wrapfn.name``) with
    *postfix* appended. The newly created global variable is returned.
    """
    error_gv = wrapper_module.add_global_variable(
        Type.int(),
        name=wrapfn.name + postfix,
    )
    null_value = Constant.null(error_gv.type.pointee)
    error_gv.initializer = null_value
    return error_gv
function signature of the symbol being declared
cargs: sequence of str
C type names for the arguments
mangler: a mangler function
function to use to mangle the symbol
"""
mod = builder.module
if sig.return_type == types.void:
llretty = lc.Type.void()
else:
llretty = context.get_value_type(sig.return_type)
llargs = [context.get_value_type(t) for t in sig.args]
fnty = Type.function(llretty, llargs)
mangled = mangler(name, cargs)
fn = mod.get_or_insert_function(fnty, mangled)
fn.calling_convention = target.CC_SPIR_FUNC
return fn
def declare_atomic_max_float64(lmod):
    """Declare (or fetch) ``___numba_atomic_double_max`` in *lmod*.

    The symbol is typed as a function taking a pointer-to-double and a
    double, and returning a double. The result of
    ``lmod.get_or_insert_function`` is returned as-is.
    """
    return_ty = lc.Type.double()
    param_tys = (lc.Type.pointer(lc.Type.double()), lc.Type.double())
    signature = lc.Type.function(return_ty, param_tys)
    return lmod.get_or_insert_function(signature,
                                       '___numba_atomic_double_max')
gil_state = pyapi.gil_ensure()
thread_state = pyapi.save_thread()
def as_void_ptr(arg):
    """Return *arg* bitcast to ``byte_ptr_t`` using the enclosing ``builder``."""
    casted = builder.bitcast(arg, byte_ptr_t)
    return casted
# Array count is input signature plus 1 (due to output array)
array_count = len(sig.args) + 1
parallel_for_ty = lc.Type.function(lc.Type.void(),
[byte_ptr_t] * 5 + [intp_t, ] * 2)
parallel_for = mod.get_or_insert_function(parallel_for_ty,
name='numba_parallel_for')
# Reference inner-function and link
innerfunc_fnty = lc.Type.function(
lc.Type.void(),
[byte_ptr_ptr_t, intp_ptr_t, intp_ptr_t, byte_ptr_t],
)
tmp_voidptr = mod.get_or_insert_function(
innerfunc_fnty, name=info.name,
)
wrapperlib.add_linking_library(info.library)
# Prepare call
fnptr = builder.bitcast(tmp_voidptr, byte_ptr_t)
innerargs = [as_void_ptr(x) for x
in [args, dimensions, steps, data]]
builder.call(parallel_for, [fnptr] + innerargs +
[intp_t(x) for x in (inner_ndim, array_count)])
# Release the GIL
def details(context, builder, signature, args):
    """Emit a call to the external ``numba_get_PyUnicode_ExtendedCase``.

    The symbol is declared in the builder's module as taking one
    ``types.intc`` value and returning a ``_Py_UCS4`` value, both lowered
    through ``context.get_value_type``. Only ``args[0]`` is forwarded to
    the call; *signature* is not consulted. Returns the call's result.
    """
    result_llty = context.get_value_type(_Py_UCS4)
    index_llty = context.get_value_type(types.intc)
    helper_fnty = lc.Type.function(result_llty, [index_llty])
    helper = builder.module.get_or_insert_function(
        helper_fnty,
        name="numba_get_PyUnicode_ExtendedCase",
    )
    return builder.call(helper, [args[0]])