Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def make_tuple_type(elt_types):
    """
    Return the canonical TupleT for a sequence of element types.

    Instances are cached in _tuple_types under the key tuple(elt_types), so
    equal element-type sequences share one TupleT object. This keeps the
    number of distinct tuple type objects down and speeds up equality checks.
    """
    # NOTE(review): make_tuple_type is defined twice in this file; the later
    # definition is the one bound at import time.
    key = tuple(elt_types)
    if key not in _tuple_types:
        _tuple_types[key] = TupleT(key)
    return _tuple_types[key]
class ArrayT(StructT):
    """
    Type of an array with a single scalar element type and a fixed rank.

    The low-level struct has three fields:
      data    -- a buffer of elements (make_buffer_type(elt_type))
      shape   -- an Int64 tuple of length `rank`
      strides -- an Int64 tuple of length `rank`
    """
    # NOTE(review): StructT is declared further down in this file, but Python
    # needs it to exist when this class statement runs -- confirm ordering.
    _members = ['elt_type', 'rank']

    def finalize_init(self):
        assert isinstance(self.elt_type, ScalarT)
        # shape and strides share one (memoized) Int64 tuple type.
        tuple_t = repeat_tuple(Int64, self.rank)
        self._field_types = [
            ('data', make_buffer_type(self.elt_type)),
            ('shape', tuple_t),
            ('strides', tuple_t),
        ]
        self.ctypes_repr = ctypes_struct_from_fields(self._field_types)

    def to_ctypes(self, x):
        # NOTE(review): `x` is ignored and the struct is constructed with no
        # arguments, so its data/shape/strides are not filled in from `x`.
        # Presumably they should be -- TODO confirm and implement.
        return self.ctypes_repr()

    def dtype(self):
        """Return the result of the element type's dtype()."""
        return self.elt_type.dtype()

    def __eq__(self, other):
        # Structural equality: same element type and same rank.
        return (isinstance(other, ArrayT)
                and self.elt_type == other.elt_type
                and self.rank == other.rank)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Must agree with __eq__; on Python 3 defining __eq__ alone would make
        # instances unhashable. Element types are hashable (they are used in
        # dict keys by the tuple-type cache).
        return hash((self.elt_type, self.rank))

    def combine(self, other):
        """
        Unify this array type with `other`.

        Returns self when `other` is an equal array type; otherwise raises
        IncompatibleTypes.
        """
        if self == other:
            return self
        raise IncompatibleTypes(self, other)
class StructT(ConcreteT):
    """
    Base class for types whose low-level representation is a struct.

    A struct type specifies how to translate its ordinary Python value into a
    simplified form built only from base types. Every subclass *must* supply
    a `_field_types` attribute holding a list of (name, field_type) pairs.
    """
class TupleT(StructT):
    """
    Type of a fixed-length tuple where each position has its own type.

    The low-level struct has one field per element, named elt0, elt1, ...
    """
    rank = 0
    _members = ['elt_types']

    def finalize_init(self):
        # Normalize to a tuple. Callers may pass a list or a one-shot iterator
        # (e.g. a lazy `map`), which enumerate() below would exhaust and which
        # would break the len()/== checks in the other methods.
        self.elt_types = tuple(self.elt_types)
        self._field_types = [
            ("elt%d" % i, t) for (i, t) in enumerate(self.elt_types)
        ]
        self.ctypes_repr = ctypes_struct_from_fields(self._field_types)

    def to_ctypes(self, python_tuple):
        """
        Convert a Python tuple into an instance of this type's ctypes struct.

        Each element is converted by its own element type's to_ctypes().
        """
        assert isinstance(python_tuple, tuple)
        assert len(python_tuple) == len(self.elt_types)
        converted_elts = [
            elt_type.to_ctypes(elt_value)
            for (elt_type, elt_value) in zip(self.elt_types, python_tuple)
        ]
        # Assumes ctypes_repr accepts field values positionally, as
        # ctypes.Structure subclasses do -- confirm against
        # ctypes_struct_from_fields.
        return self.ctypes_repr(*converted_elts)

    def __eq__(self, other):
        # Fast path: memoized tuple types are usually the same object.
        if self is other:
            return True
        return isinstance(other, TupleT) and self.elt_types == other.elt_types

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        # Consistent with __eq__; element types are hashable (they are used
        # as dict keys by the tuple-type cache).
        return hash(self.elt_types)

    def combine(self, other):
        """
        Unify with another tuple type element by element.

        Returns self if unification changes no element type; otherwise
        returns the memoized tuple type of the combined elements.

        Raises IncompatibleTypes if `other` is not a tuple type of the same
        length. Element-level combine() calls may raise as well.
        """
        if not (isinstance(other, TupleT)
                and len(other.elt_types) == len(self.elt_types)):
            raise IncompatibleTypes(self, other)
        combined_elt_types = tuple(
            t1.combine(t2) for (t1, t2) in zip(self.elt_types, other.elt_types)
        )
        if combined_elt_types == self.elt_types:
            return self
        return make_tuple_type(combined_elt_types)
class ClosureT(StructT):
    """
    Type of a closure: a function `fn` together with bound argument types.

    The low-level struct holds an Int64 'fn_id' field followed by one field
    per bound argument (arg0, arg1, ...).
    """
    _members = ['fn', 'args']

    def finalize_init(self):
        # Normalize `args` to a tuple:
        #   None -> (), an object without __iter__ -> 1-tuple,
        #   any other non-tuple iterable -> tuple of its items.
        if self.args is None:
            self.args = ()
        elif not hasattr(self.args, '__iter__'):
            self.args = tuple([self.args])
        elif not isinstance(self.args, tuple):
            self.args = tuple(self.args)
        # Entries are (name, type) pairs, the same shape TupleT and ArrayT
        # pass to ctypes_struct_from_fields.
        self._field_types = [('fn_id', Int64)]
        for (i, t) in enumerate(self.args):
            self._field_types.append(('arg%d' % i, t))
        self.ctypes_repr = ctypes_struct_from_fields(self._field_types)

    def combine(self, other):
        """
        Unify this closure type with `other`.

        Returns self when `other` compares equal; otherwise raises
        IncompatibleTypes.
        """
        if self == other:
            return self
        raise IncompatibleTypes(self, other)
def type_of_value(x):
    """
    Infer the type object describing a runtime value.

    - scalars (per np.isscalar)  -> type_of_scalar(x)
    - tuples                     -> memoized tuple type of the element
                                    types, inferred recursively
    - numpy.ndarray              -> ArrayT(element type of x.dtype, x.ndim)

    Raises RuntimeError for any other kind of value.
    """
    if np.isscalar(x):
        return type_of_scalar(x)
    elif isinstance(x, tuple):
        # make_tuple_type materializes the element types into a tuple and
        # reuses the cached TupleT for equal element types.
        return make_tuple_type(type_of_value(elt) for elt in x)
    elif isinstance(x, np.ndarray):
        # np.rank was deprecated and then removed from NumPy; ndim is the
        # same number of dimensions.
        return ArrayT(from_dtype(x.dtype), x.ndim)
    else:
        raise RuntimeError("Unsupported type " + str(type(x)))
def combine(self, other):
    """
    Unify tuple type `self` with `other` element by element.

    Returns `self` if no element type changes; otherwise returns the memoized
    tuple type of the combined elements. Raises IncompatibleTypes if `other`
    is not a TupleT of the same length.
    """
    # NOTE(review): this module-level function takes `self` and looks like a
    # stray copy of a TupleT method -- confirm whether anything calls it.
    if not (isinstance(other, TupleT)
            and len(other.elt_types) == len(self.elt_types)):
        raise IncompatibleTypes(self, other)
    combined_elt_types = tuple(
        t1.combine(t2) for (t1, t2) in zip(self.elt_types, other.elt_types)
    )
    # Compare as tuples: elt_types may be stored as a list or a tuple, and a
    # list never equals a tuple.
    if combined_elt_types == tuple(self.elt_types):
        return self
    return make_tuple_type(combined_elt_types)
def repeat_tuple(t, n):
    """
    Return the memoized n-tuple type t*t*...*t (n copies of base type t).

    Delegates to make_tuple_type, so it shares the same _tuple_types cache
    entry (key tuple([t] * n)) instead of duplicating the caching logic.
    """
    return make_tuple_type([t] * n)
def make_tuple_type(elt_types):
    """
    Memoized constructor for tuple types.

    Looks up tuple(elt_types) in _tuple_types and returns the stored TupleT,
    creating and caching one on first use. Sharing instances avoids building
    many distinct tuple type objects and speeds up equality checks.
    """
    # NOTE(review): an identical make_tuple_type appears earlier in this file;
    # this later definition is the one bound at import time.
    cache_key = tuple(elt_types)
    if cache_key in _tuple_types:
        return _tuple_types[cache_key]
    new_type = TupleT(cache_key)
    _tuple_types[cache_key] = new_type
    return new_type
def generic_value_to_scalar(gv, t):
    """
    Convert GenericValue `gv` back into a scalar of parakeet type `t`.

    Integer types are read with gv.as_int(), as is Bool, which
    scalar_to_generic_value encodes as an int8. Float types are read with
    gv.as_real() at dtype_to_lltype(t.dtype). The raw value is then wrapped
    with t.dtype.type(...).
    """
    assert isinstance(t, ptype.ScalarT), "Expected %s to be scalar" % t
    if isinstance(t, ptype.IntT) or t == ptype.Bool:
        x = gv.as_int()
    else:
        assert isinstance(t, ptype.FloatT)
        x = gv.as_real(dtype_to_lltype(t.dtype))
    return t.dtype.type(x)
def scalar_to_generic_value(x, t):
    """
    Wrap scalar `x` in a GenericValue according to parakeet type `t`.

    Float types become GenericValue.real at dtype_to_lltype(t.dtype); Bool is
    stored as an int8 via GenericValue.int; every remaining type must be an
    integer type and becomes GenericValue.int at dtype_to_lltype(t.dtype).
    """
    if isinstance(t, ptype.FloatT):
        return GenericValue.real(dtype_to_lltype(t.dtype), x)
    if t == ptype.Bool:
        return GenericValue.int(int8_t, x)
    # Whatever is left is expected to be an integer type.
    assert isinstance(t, ptype.IntT)
    int_lltype = dtype_to_lltype(t.dtype)
    return GenericValue.int(int_lltype, x)