Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def make_tuple_type(elt_types):
    """
    Return the shared TupleT for a sequence of element types.

    Instances are memoized in the global ``_tuple_types`` mapping, keyed by
    ``tuple(elt_types)``. Repeated requests for the same element types
    therefore hand back the very same object. This keeps the number of
    distinct tuple type objects down and keeps equality checks fast.
    """
    key = tuple(elt_types)
    if key not in _tuple_types:
        # First request for this element-type sequence: build and cache it.
        fresh = TupleT(key)
        _tuple_types[key] = fresh
        return fresh
    return _tuple_types[key]
class ArrayT(StructT):
    """
    Array type parameterised by a scalar element type and a rank.

    The struct layout is a ``data`` buffer of the element type followed by
    ``shape`` and ``strides`` fields, both typed as
    ``repeat_tuple(Int64, rank)``.
    """
    _members = ['elt_type', 'rank']

    def finalize_init(self):
        """Build the (data, shape, strides) field list and its ctypes struct."""
        assert isinstance(self.elt_type, ScalarT)
        tuple_t = repeat_tuple(Int64, self.rank)
        self._field_types = [
            ('data', make_buffer_type(self.elt_type)),
            ('shape', tuple_t),
            ('strides', tuple_t),
        ]
        self.ctypes_repr = ctypes_struct_from_fields(self._field_types)

    def to_ctypes(self, x):
        # NOTE(review): `x` is ignored. This returns a default-constructed
        # struct, so nothing from the input is copied into data/shape/strides.
        # Confirm whether a real conversion is intended here.
        return self.ctypes_repr()

    def dtype(self):
        """Return the dtype of the element type (delegates to ``elt_type.dtype()``)."""
        return self.elt_type.dtype()

    def __eq__(self, other):
        # Two array types are equal when both element type and rank match.
        return (isinstance(other, ArrayT)
                and self.elt_type == other.elt_type
                and self.rank == other.rank)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self.__eq__(other)

    def __hash__(self):
        # Overriding __eq__ without __hash__ makes instances unhashable on
        # Python 3. They would then fail as parts of dict keys such as the
        # tuples cached by make_tuple_type. Hash only the rank: equal array
        # types always share a rank, and this assumes nothing about how
        # elt_type hashes.
        return hash(self.rank)

    def combine(self, other):
        """Return ``self`` if ``other`` is an equal array type, else raise IncompatibleTypes."""
        if self == other:
            return self
        raise IncompatibleTypes(self, other)


class StructT(ConcreteT):
    """
    Base class for struct-like types.

    A struct type knows how to translate a value from its ordinary Python
    representation into a simplified form built only from base types.
    Every subclass *must* provide a ``_field_types`` attribute holding a
    list of ``(name, field_type)`` pairs.
    """


class TupleT(StructT):
    """
    Tuple type over a fixed sequence of element types.

    Prefer ``make_tuple_type`` to constructing this directly, so that equal
    element-type sequences share a single TupleT object.
    """
    rank = 0
    _members = ['elt_types']

    def finalize_init(self):
        """Build one field per element (elt0, elt1, ...) and the ctypes struct."""
        self._field_types = [
            ("elt%d" % i, t) for (i, t) in enumerate(self.elt_types)
        ]
        self.ctypes_repr = ctypes_struct_from_fields(self._field_types)

    def to_ctypes(self, python_tuple):
        """
        Convert a Python tuple into an instance of this type's ctypes struct.

        Each element is converted with its element type's ``to_ctypes``.
        Raises AssertionError if ``python_tuple`` is not a tuple of the
        expected length.
        """
        assert isinstance(python_tuple, tuple)
        assert len(python_tuple) == len(self.elt_types)
        converted_elts = [
            elt_type.to_ctypes(elt_value)
            for (elt_type, elt_value) in zip(self.elt_types, python_tuple)
        ]
        # Fields were declared in elt_types order in finalize_init, so the
        # converted values can be passed positionally. This assumes
        # ctypes_struct_from_fields yields a ctypes.Structure subclass that
        # preserves field order (TODO confirm).
        return self.ctypes_repr(*converted_elts)
class ClosureT(StructT):
    """
    Struct type pairing ``fn`` with a sequence of argument types ``args``.

    The struct layout is an Int64 ``fn_id`` field followed by one field per
    entry in ``args`` (named ``arg0``, ``arg1``, ...).
    """
    _members = ['fn', 'args']

    def finalize_init(self):
        """Normalise ``args`` to a tuple and build the struct layout."""
        # None -> (); a single non-iterable -> 1-tuple; other iterables -> tuple.
        if self.args is None:
            self.args = ()
        elif not hasattr(self.args, '__iter__'):
            self.args = tuple([self.args])
        elif not isinstance(self.args, tuple):
            self.args = tuple(self.args)

        # Build a single list under the `_field_types` name that StructT
        # requires. The same list is extended with the argument fields and
        # then handed to ctypes_struct_from_fields.
        self._field_types = [('fn_id', Int64)]
        for (i, t) in enumerate(self.args):
            # Store the type itself, as 'fn_id' and the other struct types do:
            # ctypes_struct_from_fields receives (name, type) pairs.
            self._field_types.append(('arg%d' % i, t))
        self.ctypes_repr = ctypes_struct_from_fields(self._field_types)