How to use the snorkel.classification.Trainer class in snorkel

To help you get started, we've selected a few snorkel examples that show popular ways Trainer is used in public projects.

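At a high level, Trainer drives the training loop for a MultitaskClassifier: you configure it entirely through keyword arguments (optimizer, learning rate scheduler, logging, checkpointing) and then call fit with the model and a list of DictDataLoaders. The sketch below illustrates that basic pattern using only keyword arguments that appear in the examples on this page; model, train_dl, and valid_dl are hypothetical placeholders for a MultitaskClassifier and dataloaders built as in the snippets further down.

from snorkel.classification import Trainer

# A minimal sketch. model is assumed to be an already-built MultitaskClassifier,
# and train_dl / valid_dl already-built DictDataLoaders (placeholder names);
# the snippets below show how these objects are typically constructed.
trainer = Trainer(
    n_epochs=5,             # number of passes over the training dataloaders
    lr=1e-3,                # learning rate
    optimizer="adam",       # also accepts "sgd" or "adamax" (see test_optimizer_init below)
    lr_scheduler="linear",  # also "constant", "exponential", or "step" (see test_scheduler_init below)
    progress_bar=True,
)
trainer.fit(model, [train_dl])  # train on a list of DictDataLoaders
model.score([valid_dl])         # evaluate on held-out data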

From snorkel-team/snorkel: test/classification/training/test_trainer.py (view on GitHub)
def test_warmup(self):
        lr_scheduler_config = {"warmup_steps": 1, "warmup_unit": "batches"}
        trainer = Trainer(**base_config, lr_scheduler_config=lr_scheduler_config)
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, 1)

        lr_scheduler_config = {"warmup_steps": 1, "warmup_unit": "epochs"}
        trainer = Trainer(**base_config, lr_scheduler_config=lr_scheduler_config)
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, BATCHES_PER_EPOCH)

        lr_scheduler_config = {"warmup_percentage": 1 / BATCHES_PER_EPOCH}
        trainer = Trainer(**base_config, lr_scheduler_config=lr_scheduler_config)
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, 1)

From snorkel-team/snorkel: test/classification/training/test_trainer.py (view on GitHub)
def test_optimizer_init(self):
        trainer = Trainer(**base_config, optimizer="sgd")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.optimizer, optim.SGD)

        trainer = Trainer(**base_config, optimizer="adam")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.optimizer, optim.Adam)

        trainer = Trainer(**base_config, optimizer="adamax")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.optimizer, optim.Adamax)

        with self.assertRaisesRegex(ValueError, "Unrecognized optimizer"):
            trainer = Trainer(**base_config, optimizer="foo")
            trainer.fit(model, [dataloaders[0]])

From snorkel-team/snorkel: test/classification/training/test_trainer.py (view on GitHub)
def test_log_writer_json(self):
        # Addresses issue #1439
        # Confirm that a log file is written to the specified location after training
        run_name = "log.json"
        with tempfile.TemporaryDirectory() as temp_dir:
            log_writer_config = {"log_dir": temp_dir, "run_name": run_name}
            trainer = Trainer(
                **base_config,
                logging=True,
                log_writer="json",
                log_writer_config=log_writer_config,
            )
            trainer.fit(model, [dataloaders[0]])
            log_path = os.path.join(trainer.log_writer.log_dir, run_name)
            with open(log_path, "r") as f:
                log = json.load(f)
            self.assertIn("model/all/train/loss", log)
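
Outside the test suite, the same logging options are often combined with checkpointing for a real training run. The sketch below is a minimal illustration that reuses only keyword arguments appearing elsewhere on this page; model and train_dl are hypothetical placeholders, and the log and checkpoint directories are arbitrary example paths.

import json
import os

from snorkel.classification import Trainer

# model (a MultitaskClassifier) and train_dl (a DictDataLoader) are assumed
# to exist already; see the other snippets on this page for how to build them.
trainer = Trainer(
    n_epochs=10,
    lr=1e-3,
    logging=True,
    log_writer="json",
    log_writer_config={"log_dir": "logs", "run_name": "run.json"},
    checkpointing=True,
    checkpointer_config={"checkpoint_dir": "checkpoint"},
)
trainer.fit(model, [train_dl])

# Inspect the metrics written to the JSON log after training;
# keys look like "model/all/train/loss".
log_path = os.path.join(trainer.log_writer.log_dir, "run.json")
with open(log_path, "r") as f:
    print(sorted(json.load(f).keys()))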

From snorkel-team/snorkel: test/classification/test_classifier_convergence.py (view on GitHub)
        for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]):
            df = create_data(N_TRAIN, offset)
            dataloader = create_dataloader(df, "train", task_name)
            dataloaders.append(dataloader)

        for offset, task_name in zip([0.0, 0.25], ["task1", "task2"]):
            df = create_data(N_VALID, offset)
            dataloader = create_dataloader(df, "valid", task_name)
            dataloaders.append(dataloader)

        task1 = create_task("task1", module_suffixes=["A", "A"])
        task2 = create_task("task2", module_suffixes=["A", "B"])
        model = MultitaskClassifier(tasks=[task1, task2])

        # Train
        trainer = Trainer(lr=0.001, n_epochs=10, progress_bar=False)
        trainer.fit(model, dataloaders)
        scores = model.score(dataloaders)

        # Confirm near perfect scores on both tasks
        for idx, task_name in enumerate(["task1", "task2"]):
            self.assertGreater(scores[f"{task_name}/TestData/valid/accuracy"], 0.95)

            # Calculate/check train/val loss
            train_dataset = dataloaders[idx].dataset
            train_loss_output = model.calculate_loss(
                train_dataset.X_dict, train_dataset.Y_dict
            )
            train_loss = train_loss_output[0][task_name].item()
            self.assertLess(train_loss, 0.05)

            val_dataset = dataloaders[2 + idx].dataset
            val_loss_output = model.calculate_loss(
                val_dataset.X_dict, val_dataset.Y_dict
            )
            val_loss = val_loss_output[0][task_name].item()
            self.assertLess(val_loss, 0.05)

From snorkel-team/snorkel: test/classification/training/test_trainer.py (view on GitHub)
def test_log_writer_init(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            log_writer_config = {"log_dir": temp_dir}
            trainer = Trainer(
                **base_config,
                logging=True,
                log_writer="json",
                log_writer_config=log_writer_config,
            )
            trainer.fit(model, [dataloaders[0]])
            self.assertIsInstance(trainer.log_writer, LogWriter)

            log_writer_config = {"log_dir": temp_dir}
            trainer = Trainer(
                **base_config,
                logging=True,
                log_writer="tensorboard",
                log_writer_config=log_writer_config,
            )
            trainer.fit(model, [dataloaders[0]])
            self.assertIsInstance(trainer.log_writer, TensorBoardWriter)

            log_writer_config = {"log_dir": temp_dir}
            with self.assertRaisesRegex(ValueError, "Unrecognized writer"):
                trainer = Trainer(
                    **base_config,
                    logging=True,
                    log_writer="foo",
                    log_writer_config=log_writer_config,
                )
                trainer.fit(model, [dataloaders[0]])

From snorkel-team/snorkel: test/classification/training/test_trainer.py (view on GitHub)
def test_scheduler_init(self):
        trainer = Trainer(**base_config, lr_scheduler="constant")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsNone(trainer.lr_scheduler)

        trainer = Trainer(**base_config, lr_scheduler="linear")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.lr_scheduler, optim.lr_scheduler.LambdaLR)

        trainer = Trainer(**base_config, lr_scheduler="exponential")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.lr_scheduler, optim.lr_scheduler.ExponentialLR)

        trainer = Trainer(**base_config, lr_scheduler="step")
        trainer.fit(model, [dataloaders[0]])
        self.assertIsInstance(trainer.lr_scheduler, optim.lr_scheduler.StepLR)

        with self.assertRaisesRegex(ValueError, "Unrecognized lr scheduler"):
            trainer = Trainer(**base_config, lr_scheduler="foo")
            trainer.fit(model, [dataloaders[0]])

From snorkel-team/snorkel: test/classification/training/test_trainer.py (view on GitHub)
def test_trainer_twotask(self):
        """Train a model with overlapping modules and flows"""
        multitask_model = MultitaskClassifier(tasks)
        trainer = Trainer(**base_config)
        trainer.fit(multitask_model, dataloaders)

From snorkel-team/snorkel-tutorials: visual_relation/visual_relation_tutorial.py (view on GitHub)
# #### Define Model Architecture

# %%
import torchvision.models as models

# initialize pretrained feature extractor
cnn = models.resnet18(pretrained=True)
model = create_model(cnn)

# %% [markdown]
# ### Train and Evaluate Model

# %% {"tags": ["md-exclude-output"]}
from snorkel.classification import Trainer

trainer = Trainer(
    n_epochs=1,  # increase for improved performance
    lr=1e-3,
    checkpointing=True,
    checkpointer_config={"checkpoint_dir": "checkpoint"},
)
trainer.fit(model, [dl_train])

# %%
model.score([dl_valid])

From snorkel-team/snorkel-tutorials: spam/03_spam_data_slicing_tutorial.py (view on GitHub)
test_dl = create_dict_dataloader(X_test, Y_test, "train")
test_dl_slice = slice_model.make_slice_dataloader(
    test_dl.dataset, S_test, shuffle=False, batch_size=BATCH_SIZE
)

# %% [markdown]
# ### Representation learning with slices

# %% [markdown]
# Using Snorkel's [`Trainer`](https://snorkel.readthedocs.io/en/master/packages/_autosummary/classification/snorkel.classification.Trainer.html), we fit our classifier with the training set dataloader.

# %%
from snorkel.classification import Trainer

# For demonstration purposes, we set n_epochs=2
trainer = Trainer(n_epochs=2, lr=1e-4, progress_bar=True)
trainer.fit(slice_model, [train_dl_slice])

# %% [markdown]
# At inference time, the primary task head (`spam_task`) will make all final predictions.
# We'd like to evaluate all the slice heads on the original task head — [`score_slices`](https://snorkel.readthedocs.io/en/v0.9.3/packages/_autosummary/slicing/snorkel.slicing.SliceAwareClassifier.html#snorkel.slicing.SliceAwareClassifier.score_slices) remaps all slice-related labels, denoted `spam_task_slice:{slice_name}_pred`, to be evaluated on the `spam_task`.

# %%
slice_model.score_slices([test_dl_slice], as_dataframe=True)