How to use the kymatio.Scattering2D function in kymatio

To help you get started, we’ve selected a few kymatio examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github kymatio / kymatio / examples / 2d / mnist.py View on Github external
scatter + linear achieves 99.15% in 15 epochs
        scatter + cnn achieves 99.3% in 15 epochs

    """
    parser = argparse.ArgumentParser(description='MNIST scattering  + hybrid examples')
    parser.add_argument('--mode', type=int, default=2,help='scattering 1st or 2nd order')
    parser.add_argument('--classifier', type=str, default='linear',help='classifier model')
    args = parser.parse_args()
    assert(args.classifier in ['linear','mlp','cnn'])

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.mode == 1:
        scattering = Scattering2D(J=2, shape=(28, 28), max_order=1)
        K = 17
    else:
        scattering = Scattering2D(J=2, shape=(28, 28))
        K = 81
    if use_cuda:
        scattering = scattering.cuda()




    if args.classifier == 'cnn':
        model = nn.Sequential(
            View(K, 7, 7),
            nn.BatchNorm2d(K),
            nn.Conv2d(K, 64, 3,padding=1), nn.ReLU(),
            View(64*7*7),
github kymatio / kymatio / examples / 2d / plot_invert_scattering_torch.py View on Github external
# Display the untouched source image for later visual comparison.
plt.imshow(src_img)
plt.title("Original image")

# Gradient-descent step budget per configuration; the main loop below bumps
# this up when running on CUDA.
max_iter = 5

# Reorder axes: height-width-channel -> channel-height-width.
src_img = np.moveaxis(src_img, -1, 0)
print("Image shape: ", src_img.shape)
channels, height, width = src_img.shape

###############################################################################
#  Main loop
# ----------
for order in [1]:
    for J in [2, 4]:

        # Compute scattering coefficients
        scattering = Scattering2D(J=J, shape=(height, width), max_order=order, frontend='torch')
        if device == "cuda":
            scattering = scattering.cuda()
            max_iter = 500
        src_img_tensor = torch.from_numpy(src_img).to(device).contiguous()
        scattering_coefficients = scattering(src_img_tensor)

        # Create trainable input image
        input_tensor = torch.rand(src_img.shape, requires_grad=True, device=device)

        # Optimizer hyperparams
        optimizer = optim.Adam([input_tensor], lr=1)

        # Training
        best_img = None
        best_loss = float("inf")
        for epoch in range(1, max_iter):
github kymatio / kymatio / examples / 2d / regularized_inverse_scattering_MNIST.py View on Github external
dir_to_save = get_cache_dir('reg_inverse_example')

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Pixel values should be in [-1,1]
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir, train=True, download=True, transform=transforms_to_apply)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=True)

    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    fixed_batch = next(iter(fixed_dataloader))
    fixed_batch = fixed_batch[0].float().cuda()

    scattering = Scattering(J=2, shape=(28, 28))
    scattering.cuda()

    scattering_fixed_batch = scattering(fixed_batch).squeeze(1)
    num_input_channels = scattering_fixed_batch.shape[1]
    num_hidden_channels = num_input_channels

    generator = Generator(num_input_channels, num_hidden_channels)
    generator.cuda()
    generator.train()

    # Either train the network or load a trained model
    ##################################################
    if load_model:
        filename_model = os.path.join(dir_to_save, 'model.pth')
        generator.load_state_dict(torch.load(filename_model))
    else:
github kymatio / kymatio / examples / 2d / plot_mnist_classify_torch.py View on Github external
data, target = data.to(device), target.to(device)
            output = model(scattering(data))
            pred = output.max(1, keepdim = True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    return 100. * correct / len(test_loader.dataset)
############################################################################
# Train a simple Hybrid Scattering + CNN model on MNIST.

from kymatio import Scattering2D
import torch.optim
import math


# Evaluate a linear model on top of the scattering coefficients.
#
# The classifier's input size is derived from the scattering parameters so
# that changing J (or L) cannot silently desynchronize the hand-written
# layer sizes below.
SCATTERING_J = 2   # number of scales; feature maps are subsampled by 2**J
SCATTERING_L = 8   # number of orientations (the Scattering2D default)
IMAGE_SIZE = 28    # MNIST images are IMAGE_SIZE x IMAGE_SIZE

scattering = Scattering2D(shape=(IMAGE_SIZE, IMAGE_SIZE), J=SCATTERING_J,
                          L=SCATTERING_L, frontend='torch')

# Number of output coefficients for each spatial position of a 2nd-order
# transform: 1 zeroth-order + J*L first-order + L**2 * J*(J-1)/2 second-order
# channels. For J=2, L=8 this is 1 + 16 + 64 = 81.
K = (1 + SCATTERING_J * SCATTERING_L
     + SCATTERING_L ** 2 * SCATTERING_J * (SCATTERING_J - 1) // 2)

# Side length of each coefficient map after subsampling by 2**J
# (28 // 4 == 7 for MNIST with J=2).
SPATIAL_SIZE = IMAGE_SIZE // 2 ** SCATTERING_J

if use_cuda:
    scattering = scattering.cuda()

# BatchNorm over the K scattering channels, then flatten into a single
# linear layer producing 10 class scores.
model = nn.Sequential(
    View(K, SPATIAL_SIZE, SPATIAL_SIZE),
    nn.BatchNorm2d(K),
    View(K * SPATIAL_SIZE * SPATIAL_SIZE),
    nn.Linear(K * SPATIAL_SIZE * SPATIAL_SIZE, 10),
).to(device)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9,
                            weight_decay=0.0005)
for epoch in range(0, 20):
github kymatio / kymatio / examples / 2d / cifar_resnet.py View on Github external
scatter 1st order +
        scatter 2nd order + linear achieves 70.5% in 90 epochs

        scatter + cnn achieves 88% in 15 epochs

    """
    parser = argparse.ArgumentParser(description='CIFAR scattering  + hybrid examples')
    parser.add_argument('--mode', type=int, default=1,help='scattering 1st or 2nd order')
    parser.add_argument('--width', type=int, default=2,help='width factor for resnet')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.mode == 1:
        scattering = Scattering2D(J=2, shape=(32, 32), max_order=1)
        K = 17*3
    else:
        scattering = Scattering2D(J=2, shape=(32, 32))
        K = 81*3
    if use_cuda:
        scattering = scattering.cuda()




    model = Scattering2dResNet(K, args.width).to(device)

    # DataLoaders
    if use_cuda:
        num_workers = 4
        pin_memory = True
github kymatio / kymatio / examples / 2d / mnist.py View on Github external
"""
    parser = argparse.ArgumentParser(description='MNIST scattering  + hybrid examples')
    parser.add_argument('--mode', type=int, default=2,help='scattering 1st or 2nd order')
    parser.add_argument('--classifier', type=str, default='linear',help='classifier model')
    args = parser.parse_args()
    assert(args.classifier in ['linear','mlp','cnn'])

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.mode == 1:
        scattering = Scattering2D(J=2, shape=(28, 28), max_order=1)
        K = 17
    else:
        scattering = Scattering2D(J=2, shape=(28, 28))
        K = 81
    if use_cuda:
        scattering = scattering.cuda()




    if args.classifier == 'cnn':
        model = nn.Sequential(
            View(K, 7, 7),
            nn.BatchNorm2d(K),
            nn.Conv2d(K, 64, 3,padding=1), nn.ReLU(),
            View(64*7*7),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
            nn.Linear(512, 10)
        ).to(device)
github kymatio / kymatio / examples / 2d / cifar_resnet.py View on Github external
scatter + cnn achieves 88% in 15 epochs

    """
    parser = argparse.ArgumentParser(description='CIFAR scattering  + hybrid examples')
    parser.add_argument('--mode', type=int, default=1,help='scattering 1st or 2nd order')
    parser.add_argument('--width', type=int, default=2,help='width factor for resnet')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.mode == 1:
        scattering = Scattering2D(J=2, shape=(32, 32), max_order=1)
        K = 17*3
    else:
        scattering = Scattering2D(J=2, shape=(32, 32))
        K = 81*3
    if use_cuda:
        scattering = scattering.cuda()




    model = Scattering2dResNet(K, args.width).to(device)

    # DataLoaders
    if use_cuda:
        num_workers = 4
        pin_memory = True
    else:
        num_workers = None
        pin_memory = False
github kymatio / kymatio / examples / 2d / cifar.py View on Github external
"""
    parser = argparse.ArgumentParser(description='MNIST scattering  + hybrid examples')
    parser.add_argument('--mode', type=int, default=1,help='scattering 1st or 2nd order')
    parser.add_argument('--classifier', type=str, default='cnn',help='classifier model')
    args = parser.parse_args()
    assert(args.classifier in ['linear','mlp','cnn'])

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    if args.mode == 1:
        scattering = Scattering2D(J=2, shape=(32, 32), max_order=1)
        K = 17*3
    else:
        scattering = Scattering2D(J=2, shape=(32, 32))
        K = 81*3
    if use_cuda:
        scattering = scattering.cuda()




    model = Scattering2dCNN(K,args.classifier).to(device)

    # DataLoaders
    if use_cuda:
        num_workers = 4
        pin_memory = True
    else:
        num_workers = None
        pin_memory = False