import random
import unittest

import numpy as np
import pandas as pd
import torch

from snorkel.classification import DictDataLoader, DictDataset, MultitaskClassifier
from snorkel.classification.training.schedulers import (
    SequentialScheduler,
    ShuffledScheduler,
)

NUM_EXAMPLES = 10  # assumed value; any small positive constant works here
BATCH_SIZE = 2  # assumed value


# Excerpted from a unittest.TestCase; `self.task1` is assumed to be a Task
# named "task1" created in setUp().
def test_remapped_labels(self):
    # Additional label keys in the Y_dict: without remapping, the model
    # should ignore them.
    task_name = self.task1.name
    X = torch.FloatTensor([[i, i] for i in range(NUM_EXAMPLES)])
    Y = torch.ones(NUM_EXAMPLES).long()

    Y_dict = {task_name: Y, "other_task": Y}
    dataset = DictDataset(
        name="dataset", split="train", X_dict={"data": X}, Y_dict=Y_dict
    )
    dataloader = DictDataLoader(dataset, batch_size=BATCH_SIZE)

    model = MultitaskClassifier([self.task1])
    loss_dict, count_dict = model.calculate_loss(dataset.X_dict, dataset.Y_dict)
    self.assertIn("task1", loss_dict)

    # Predicting and scoring without remapping: the extra label set is ignored
    results = model.predict(dataloader)
    self.assertIn("task1", results["golds"])
    self.assertNotIn("other_task", results["golds"])
    scores = model.score([dataloader])
    self.assertIn("task1/dataset/train/accuracy", scores)
    self.assertNotIn("other_task/dataset/train/accuracy", scores)

    # With remapping, the "other_task" label set is scored under the existing head
    results = model.predict(dataloader, remap_labels={"other_task": task_name})
    self.assertIn("task1", results["golds"])
def create_dataloader(task_name="task", split="train", **kwargs):
    X = torch.FloatTensor([[i, i] for i in range(NUM_EXAMPLES)])
    Y = torch.ones(NUM_EXAMPLES).long()

    dataset = DictDataset(
        name="dataset", split=split, X_dict={"data": X}, Y_dict={task_name: Y}
    )

    dataloader = DictDataLoader(dataset, batch_size=BATCH_SIZE, **kwargs)
    return dataloader
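
# Usage sketch for the helper above: extra keyword arguments are forwarded to
# DictDataLoader (and on to torch.utils.data.DataLoader), e.g. shuffle=True.
train_dataloader = create_dataloader("task1", "train", shuffle=True)
valid_dataloader = create_dataloader("task1", "valid")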

# Excerpted TestCase method (method name assumed) exercising how DictDataLoader
# pads variable-length inputs into fixed-shape batches.
def test_classifier_dataloader(self):
    x1 = [
        torch.Tensor([1]),
        torch.Tensor([1, 2]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1, 2, 3, 4]),
        torch.Tensor([1, 2, 3, 4, 5]),
    ]
    y1 = torch.Tensor([0, 0, 0, 0, 0])

    x2 = [
        torch.Tensor([1, 2, 3, 4, 5]),
        torch.Tensor([1, 2, 3, 4]),
        torch.Tensor([1, 2, 3]),
        torch.Tensor([1, 2]),
        torch.Tensor([1]),
    ]
    y2 = torch.Tensor([1, 1, 1, 1, 1])
    dataset = DictDataset(
        name="new_data",
        split="train",
        X_dict={"data1": x1, "data2": x2},
        Y_dict={"task1": y1, "task2": y2},
    )

    dataloader1 = DictDataLoader(dataset=dataset, batch_size=2)
    x_batch, y_batch = next(iter(dataloader1))

    # Check that the dataloader is correctly constructed: within a batch,
    # shorter sequences are right-padded with zeros to the longest one
    self.assertEqual(dataloader1.dataset.split, "train")
    self.assertTrue(torch.equal(x_batch["data1"], torch.Tensor([[1, 0], [1, 2]])))
    self.assertTrue(
        torch.equal(
            x_batch["data2"], torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0]])
        )
    )
    self.assertTrue(torch.equal(y_batch["task1"], torch.Tensor([0, 0])))
    self.assertTrue(torch.equal(y_batch["task2"], torch.Tensor([1, 1])))

    dataloader2 = DictDataLoader(dataset=dataset, batch_size=3)
    x_batch, y_batch = next(iter(dataloader2))

    # Check that a dataloader with a different batch size is correctly constructed
    self.assertEqual(dataloader2.dataset.split, "train")
    self.assertTrue(
        torch.equal(
            x_batch["data1"], torch.Tensor([[1, 0, 0], [1, 2, 0], [1, 2, 3]])
        )
    )
    self.assertTrue(
        torch.equal(
            x_batch["data2"],
            torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0]]),
        )
    )
def create_dataloader(df: pd.DataFrame, split: str) -> DictDataLoader:
    dataset = DictDataset(
        name="TestData",
        split=split,
        X_dict={
            "coordinates": torch.stack(
                (torch.tensor(df["x1"]), torch.tensor(df["x2"])), dim=1
            )
        },
        Y_dict={"task": torch.tensor(df["y"], dtype=torch.long)},
    )

    dataloader = DictDataLoader(
        dataset=dataset, batch_size=4, shuffle=(dataset.split == "train")
    )
    return dataloader
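
# Usage sketch with a toy DataFrame; the column names x1, x2, and y follow
# the helper above.
toy_df = pd.DataFrame(
    {"x1": [0.1, 0.4, 0.8, 0.3], "x2": [0.7, 0.2, 0.9, 0.5], "y": [0, 1, 1, 0]}
)
toy_dl = create_dataloader(toy_df, "train")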
dataset1 = DictDataset(
    "d1",
    "train",
    X_dict={"data": [0, 1, 2, 3, 4]},
    Y_dict={"labels": torch.LongTensor([1, 1, 1, 1, 1])},
)
dataset2 = DictDataset(
    "d2",
    "train",
    X_dict={"data": [5, 6, 7, 8, 9]},
    Y_dict={"labels": torch.LongTensor([2, 2, 2, 2, 2])},
)

dataloader1 = DictDataLoader(dataset1, batch_size=2)
dataloader2 = DictDataLoader(dataset2, batch_size=2)
dataloaders = [dataloader1, dataloader2]


class SequentialTest(unittest.TestCase):
    def test_sequential(self):
        scheduler = SequentialScheduler()
        data = []
        for (batch, dl) in scheduler.get_batches(dataloaders):
            X_dict, Y_dict = batch
            data.extend(X_dict["data"])
        # Sequential scheduling exhausts dataloader1 (values 0-4) before
        # moving to dataloader2 (values 5-9), so the data arrives sorted
        self.assertEqual(data, sorted(data))
    def test_shuffled(self):
        # Fix the seeds so the shuffled batch order is reproducible
        random.seed(123)
        np.random.seed(123)
        torch.manual_seed(123)
        # Assumed continuation, mirroring test_sequential: ShuffledScheduler
        # draws batches from the dataloaders in random order
        scheduler = ShuffledScheduler()
        data = []
        for (batch, dl) in scheduler.get_batches(dataloaders):
            X_dict, Y_dict = batch
            data.extend(X_dict["data"])
        self.assertNotEqual(data, sorted(data))
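
# Illustrative sketch: each (batch, dl) pair yielded by get_batches carries the
# originating dataloader, so the scheduler's visit order can be traced.
scheduler = SequentialScheduler()
visit_order = [dl.dataset.name for _, dl in scheduler.get_batches(dataloaders)]
print(visit_order)  # ['d1', 'd1', 'd1', 'd2', 'd2', 'd2'] with batch_size=2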
from model import SceneGraphDataset, create_model

# label_model and L_train come from the labeling step earlier in the tutorial
df_train["labels"] = label_model.predict(L_train)

if sample:
    TRAIN_DIR = "data/VRD/sg_dataset/samples"
else:
    TRAIN_DIR = "data/VRD/sg_dataset/sg_train_images"

dl_train = DictDataLoader(
    SceneGraphDataset("train_dataset", "train", TRAIN_DIR, df_train),
    batch_size=16,
    shuffle=True,
)

dl_valid = DictDataLoader(
    SceneGraphDataset("valid_dataset", "valid", TRAIN_DIR, df_valid),
    batch_size=16,
    shuffle=False,
)
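
# Sanity-check sketch: a DictDataLoader batch is an (X_dict, Y_dict) pair of
# dictionaries, so we can peek at the fields of one validation batch.
X_batch, Y_batch = next(iter(dl_valid))
print(list(X_batch.keys()), list(Y_batch.keys()))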
# %% [markdown]
# #### Define Model Architecture
# %%
import torchvision.models as models
# Initialize a pretrained ResNet-18 feature extractor
cnn = models.resnet18(pretrained=True)
model = create_model(cnn)
# %% [markdown]
# `DictDataLoader` is a wrapper around `torch.utils.data.DataLoader` that applies a collate function suited to `DictDataset`.
# %%
from snorkel.classification import DictDataset, DictDataLoader
dataloaders = []
for task_name in ["circle", "square"]:
    for split, X, Y in (
        ("train", X_train, Y_train),
        ("valid", X_valid, Y_valid),
        ("test", X_test, Y_test),
    ):
        X_dict = {f"{task_name}_data": torch.FloatTensor(X[task_name])}
        Y_dict = {f"{task_name}_task": torch.LongTensor(Y[task_name])}
        dataset = DictDataset(f"{task_name}Dataset", split, X_dict, Y_dict)
        dataloader = DictDataLoader(dataset, batch_size=32)
        dataloaders.append(dataloader)
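
# %% [markdown]
# As a quick check, we can pull a single batch from the first dataloader (the `circle` train split): `DictDataLoader` yields an `(X_dict, Y_dict)` pair of dictionaries keyed by the field and label-set names defined above.

# %%
X_batch_dict, Y_batch_dict = next(iter(dataloaders[0]))
print(X_batch_dict["circle_data"].shape, Y_batch_dict["circle_task"].shape)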
# %% [markdown]
# We now have 6 data loaders, one for each split (`train`, `valid`, `test`) of each task (`circle_task` and `square_task`).
# %% [markdown]
# ## Define Model
# %% [markdown]
# Now we'll define the `MultitaskClassifier` model, a PyTorch multi-task classifier.
# We'll instantiate it from a list of `Tasks`.
# %% [markdown]
# ### Tasks
# %% [markdown]