Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""This tests our ability to simplify torch.Tensor objects
using "torch" serialization strategy.
At the time of writing, tensors simplify to a tuple where the
first value in the tuple is the tensor's ID and the second
value is a serialized version of the Tensor (serialized
by PyTorch's torch.save method)
"""
me = workers["me"]
# create a tensor
input = Tensor(numpy.random.random((100, 100)))
# simplify the tensor
output = msgpack.serde._simplify(me, input)
# make sure outer type is correct
assert type(output) == tuple
# make sure the object type ID is correct
# (0 for torch.Tensor)
assert msgpack.serde.detailers[output[0]] == torch_serde._detail_torch_tensor
# make sure inner type is correct
assert type(output[1]) == tuple
# make sure ID is correctly encoded
assert output[1][0] == input.id
# make sure tensor data type is correct
assert type(output[1][1]) == bytes
def test_set_simplify(workers):
    """Check that a set simplifies to its type code plus its simplified items.

    The expected type codes are looked up through ``msgpack.proto_type_info``
    rather than hard-coded. Because sets are unordered, the simplified items
    are compared as sets instead of position by position.

    Args:
        workers: object indexed by worker name; only ``workers["me"]`` is used.
    """
    me = workers["me"]
    # Named ``values`` rather than ``input`` so the builtin is not shadowed.
    values = {"hello", "world"}

    set_detail_code = msgpack.proto_type_info(set).code
    str_detail_code = msgpack.proto_type_info(str).code
    expected_items = {
        (str_detail_code, (b"hello",)),
        (str_detail_code, (b"world",)),
    }

    # Simplify once and check both halves of the same result, instead of
    # serializing the same value separately for each assertion.
    simplified = msgpack.serde._simplify(me, values)

    assert simplified[0] == set_detail_code
    assert set(simplified[1]) == expected_items
return [
{
"value": plan,
"simplified": (
CODE[syft.messaging.plan.plan.Plan],
(
plan.id, # (int or str) id
msgpack.serde._simplify(syft.hook.local_worker, plan.procedure), # (Procedure)
msgpack.serde._simplify(syft.hook.local_worker, plan.state), # (State)
plan.include_state, # (bool) include_state
plan.is_built, # (bool) is_built
(CODE[list], ((CODE[torch.Size], (3,)),)), # (list of torch.Size) input_shapes
# NOTE: it's uninitialized until plan.output_shape property is used
None, # (torch.Size) _output_shape
msgpack.serde._simplify(syft.hook.local_worker, plan.name), # (str) name
msgpack.serde._simplify(syft.hook.local_worker, plan.tags), # (set of str) tags
msgpack.serde._simplify(
syft.hook.local_worker, plan.description
), # (str) description
msgpack.serde._simplify(syft.hook.local_worker, []), # (list of State)
),
),
"cmp_detailed": compare,
},
{
"value": model_plan,
"simplified": (
CODE[syft.messaging.plan.plan.Plan],
(
model_plan.id, # (int or str) id
msgpack.serde._simplify(
def simplify(worker: AbstractWorker, tensor: "PromiseTensor") -> tuple:
    """Flatten a PromiseTensor into a tuple of simplified attributes.

    Args:
        worker: the worker passed through to every ``_simplify`` call.
        tensor: the PromiseTensor whose attributes are read.

    Returns:
        tuple: the simplified ``id``, ``shape``, ``obj_type``, ``plans``,
            ``tags`` and ``description`` of the tensor, in that order.
    """
    _simplify = sy.serde.msgpack.serde._simplify
    # Order matters: it defines the positional layout of the result.
    field_names = ("id", "shape", "obj_type", "plans", "tags", "description")
    return tuple(_simplify(worker, getattr(tensor, name)) for name in field_names)
# I think the pointer bug is between here
if hasattr(tensor, "child"):
chain = serde._simplify(worker, tensor.child)
# and here... leaving a reference here so I can find it later
# TODO fix pointer bug
return (
tensor.id,
tensor_bin,
chain,
grad_chain,
serde._simplify(worker, tensor.tags),
serde._simplify(worker, tensor.description),
serde._simplify(worker, worker.serializer),
)
Args:
worker (AbstractWorker): the worker doing the serialization
plan (Plan): a Plan object
Returns:
tuple: a tuple holding the unique attributes of the Plan object
"""
return (
sy.serde.msgpack.serde._simplify(worker, plan.id),
sy.serde.msgpack.serde._simplify(worker, plan.procedure),
sy.serde.msgpack.serde._simplify(worker, plan.state),
sy.serde.msgpack.serde._simplify(worker, plan.include_state),
sy.serde.msgpack.serde._simplify(worker, plan.is_built),
sy.serde.msgpack.serde._simplify(worker, plan.input_shapes),
sy.serde.msgpack.serde._simplify(worker, plan._output_shape),
sy.serde.msgpack.serde._simplify(worker, plan.name),
sy.serde.msgpack.serde._simplify(worker, plan.tags),
sy.serde.msgpack.serde._simplify(worker, plan.description),
sy.serde.msgpack.serde._simplify(worker, plan.nested_states),
)
def simplify(worker: AbstractWorker, ptr: "PointerPlan") -> tuple:
    """Flatten a PointerPlan into a tuple of its identifying attributes.

    Args:
        worker: the worker passed through to every ``_simplify`` call.
        ptr: the PointerPlan whose attributes are read.

    Returns:
        tuple: the simplified ``id``, ``id_at_location`` and ``location.id``,
            followed by ``garbage_collect_data``, which is stored as-is
            without going through ``_simplify``.
    """
    _simplify = sy.serde.msgpack.serde._simplify

    pointer_id = _simplify(worker, ptr.id)
    remote_id = _simplify(worker, ptr.id_at_location)
    location_id = _simplify(worker, ptr.location.id)

    return (pointer_id, remote_id, location_id, ptr.garbage_collect_data)
def simplify(worker: AbstractWorker, tensor: "PrivateTensor") -> tuple:
    """Flatten a PrivateTensor into a tuple of simplified attributes.

    Args:
        worker: the worker passed through to every ``_simplify`` call.
        tensor (PrivateTensor): the tensor whose attributes are read.

    Returns:
        tuple: the simplified ``id``, ``allowed_users``, ``tags`` and
            ``description``, followed by the simplified ``child`` (or
            ``None`` when the tensor has no ``child`` attribute).
    """
    _simplify = syft.serde.msgpack.serde._simplify

    # The child is simplified first, but stored last in the result.
    chain = _simplify(worker, tensor.child) if hasattr(tensor, "child") else None

    field_names = ("id", "allowed_users", "tags", "description")
    own_fields = tuple(_simplify(worker, getattr(tensor, name)) for name in field_names)
    return own_fields + (chain,)