How to use torchfcn - 10 common examples

To help you get started, we’ve selected a few torchfcn examples that show the most common ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github wkentaro / pytorch-fcn / examples / apc / train_fcn32s.py View on Github external
cuda = torch.cuda.is_available()

    batch_size = torch.cuda.device_count() * 3
    max_iter = cfg['max_iteration'] // batch_size

    torch.manual_seed(1)
    if cuda:
        torch.cuda.manual_seed(1)

    # 1. dataset

    cfg['dataset'] = cfg.get('dataset', 'v2')
    if cfg['dataset'] == 'v2':
        dataset_class = torchfcn.datasets.APC2016V2
    elif cfg['dataset'] == 'v3':
        dataset_class = torchfcn.datasets.APC2016V3
    else:
        raise ValueError('Unsupported dataset: %s' % cfg['dataset'])

    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(
        dataset_class(split='train', transform=True),
        batch_size=batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        dataset_class(split='valid', transform=True),
        batch_size=batch_size, shuffle=False, **kwargs)

    # 2. model

    n_class = len(train_loader.dataset.class_names)
    model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=cfg['nodeconv'])
    start_epoch = 0
github wkentaro / pytorch-fcn / examples / apc / train_fcn32s.py View on Github external
cfg, out = load_config_file(config_file)

    cuda = torch.cuda.is_available()

    batch_size = torch.cuda.device_count() * 3
    max_iter = cfg['max_iteration'] // batch_size

    torch.manual_seed(1)
    if cuda:
        torch.cuda.manual_seed(1)

    # 1. dataset

    cfg['dataset'] = cfg.get('dataset', 'v2')
    if cfg['dataset'] == 'v2':
        dataset_class = torchfcn.datasets.APC2016V2
    elif cfg['dataset'] == 'v3':
        dataset_class = torchfcn.datasets.APC2016V3
    else:
        raise ValueError('Unsupported dataset: %s' % cfg['dataset'])

    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(
        dataset_class(split='train', transform=True),
        batch_size=batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        dataset_class(split='valid', transform=True),
        batch_size=batch_size, shuffle=False, **kwargs)

    # 2. model

    n_class = len(train_loader.dataset.class_names)
github wkentaro / pytorch-fcn / tests / models_tests / test_fcn32s.py View on Github external
def test_get_upsampling_weight():
    """Check that a stride-2 transposed conv roughly doubles the image size.

    Loads the scikit-image ``coffee`` sample, pushes it through a
    ``ConvTranspose2d`` whose kernel is replaced by the output of
    ``get_upsampling_weight``, and asserts that each spatial dimension of
    the result is within 2 pixels of twice the input size.

    Returns:
        A ``(src, dst)`` tuple: the original image array and the upsampled
        result cast to ``uint8`` in height x width x channel order.
    """
    n_channels = 3  # same count is used for input and output channels
    ksize = 4

    src = skimage.data.coffee()

    # Reorder axes to channel-first and prepend a batch axis before handing
    # the array to torch (assumes `src` is H x W x C -- TODO confirm).
    batch = src.transpose(2, 0, 1)[np.newaxis, :, :, :]
    batch = torch.autograd.Variable(torch.from_numpy(batch).float())

    upsample = torch.nn.ConvTranspose2d(
        n_channels, n_channels, ksize, stride=2, bias=False)
    # Overwrite the randomly initialised kernel with the one under test.
    upsample.weight.data = get_upsampling_weight(
        n_channels, n_channels, ksize)

    out = upsample(batch).data.numpy()

    # Drop the batch axis and restore channel-last order.
    dst = out[0].transpose(1, 2, 0).astype(np.uint8)

    # Height first, then width: each must be ~2x the source, +/- 2 pixels.
    for axis in (0, 1):
        assert abs(src.shape[axis] * 2 - dst.shape[axis]) <= 2

    return src, dst
github wkentaro / pytorch-fcn / examples / voc / evaluate.py View on Github external
args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    model_file = args.model_file

    root = osp.expanduser('~/data/datasets')
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False,
        num_workers=4, pin_memory=True)

    n_class = len(val_loader.dataset.class_names)

    if osp.basename(model_file).startswith('fcn32s'):
        model = torchfcn.models.FCN32s(n_class=21)
    elif osp.basename(model_file).startswith('fcn16s'):
        model = torchfcn.models.FCN16s(n_class=21)
    elif osp.basename(model_file).startswith('fcn8s'):
        if osp.basename(model_file).startswith('fcn8s-atonce'):
            model = torchfcn.models.FCN8sAtOnce(n_class=21)
        else:
            model = torchfcn.models.FCN8s(n_class=21)
    else:
        raise ValueError
    if torch.cuda.is_available():
        model = model.cuda()
    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    model_data = torch.load(model_file)
    try:
        model.load_state_dict(model_data)
github wkentaro / pytorch-fcn / examples / voc / train_fcn32s.py View on Github external
def get_parameters(model, bias=False):
    import torch.nn as nn
    modules_skipped = (
        nn.ReLU,
        nn.MaxPool2d,
        nn.Dropout2d,
        nn.Sequential,
        torchfcn.models.FCN32s,
        torchfcn.models.FCN16s,
        torchfcn.models.FCN8s,
    )
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            if bias:
                yield m.bias
            else:
                yield m.weight
        elif isinstance(m, nn.ConvTranspose2d):
            # weight is frozen because it is just a bilinear upsampling
            if bias:
                assert m.bias is None
        elif isinstance(m, modules_skipped):
            continue
        else:
github wkentaro / pytorch-fcn / examples / voc / train_fcn32s.py View on Github external
torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False, **kwargs)

    # 2. model

    model = torchfcn.models.FCN32s(n_class=21)
    start_epoch = 0
    start_iteration = 0
    if resume:
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        vgg16_fcn32s = torchfcn.models.FCN32s(n_class=21)
        vgg16_fcn32s.load_state_dict(torch.load(osp.expanduser('~/data/models/torch/vgg16-fcn32s.pth')))
        model.copy_params_from_vgg16(vgg16_fcn32s, copy_fc8=False)
    if cuda:
        model = model.cuda()

    # 3. optimizer

    optim = torch.optim.SGD(
        [
            {'params': get_parameters(model, bias=False)},
            {'params': get_parameters(model, bias=True),
             'lr': cfg['lr'] * 2, 'weight_decay': 0},
        ],
        lr=cfg['lr'],
        momentum=cfg['momentum'],
        weight_decay=cfg['weight_decay'])
github wkentaro / pytorch-fcn / examples / voc / train_fcn32s.py View on Github external
# 1. dataset

    root = osp.expanduser('~/data/datasets')
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
        batch_size=1, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False, **kwargs)

    # 2. model

    model = torchfcn.models.FCN32s(n_class=21)
    start_epoch = 0
    start_iteration = 0
    if resume:
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        vgg16_fcn32s = torchfcn.models.FCN32s(n_class=21)
        vgg16_fcn32s.load_state_dict(torch.load(osp.expanduser('~/data/models/torch/vgg16-fcn32s.pth')))
        model.copy_params_from_vgg16(vgg16_fcn32s, copy_fc8=False)
    if cuda:
        model = model.cuda()

    # 3. optimizer
github wkentaro / pytorch-fcn / examples / voc / evaluate.py View on Github external
root = osp.expanduser('~/data/datasets')
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False,
        num_workers=4, pin_memory=True)

    n_class = len(val_loader.dataset.class_names)

    if osp.basename(model_file).startswith('fcn32s'):
        model = torchfcn.models.FCN32s(n_class=21)
    elif osp.basename(model_file).startswith('fcn16s'):
        model = torchfcn.models.FCN16s(n_class=21)
    elif osp.basename(model_file).startswith('fcn8s'):
        if osp.basename(model_file).startswith('fcn8s-atonce'):
            model = torchfcn.models.FCN8sAtOnce(n_class=21)
        else:
            model = torchfcn.models.FCN8s(n_class=21)
    else:
        raise ValueError
    if torch.cuda.is_available():
        model = model.cuda()
    print('==> Loading %s model file: %s' %
          (model.__class__.__name__, model_file))
    model_data = torch.load(model_file)
    try:
        model.load_state_dict(model_data)
    except Exception:
        model.load_state_dict(model_data['model_state_dict'])
    model.eval()

    print('==> Evaluating with VOC2011ClassSeg seg11valid')
github wkentaro / pytorch-fcn / examples / voc / train_fcn16s.py View on Github external
torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False, **kwargs)

    # 2. model

    model = torchfcn.models.FCN16s(n_class=21)
    start_epoch = 0
    start_iteration = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        fcn32s = torchfcn.models.FCN32s()
        state_dict = torch.load(args.pretrained_model)
        try:
            fcn32s.load_state_dict(state_dict)
        except RuntimeError:
            fcn32s.load_state_dict(state_dict['model_state_dict'])
        model.copy_params_from_fcn32s(fcn32s)
    if cuda:
        model = model.cuda()

    # 3. optimizer

    optim = torch.optim.SGD(
        [
            {'params': get_parameters(model, bias=False)},
            {'params': get_parameters(model, bias=True),
             'lr': args.lr * 2, 'weight_decay': 0},
github wkentaro / pytorch-fcn / examples / voc / train_fcn8s_atonce.py View on Github external
torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False, **kwargs)

    # 2. model

    model = torchfcn.models.FCN8sAtOnce(n_class=21)
    start_epoch = 0
    start_iteration = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)
    if cuda:
        model = model.cuda()

    # 3. optimizer

    optim = torch.optim.SGD(
        [
            {'params': get_parameters(model, bias=False)},
            {'params': get_parameters(model, bias=True),
             'lr': args.lr * 2, 'weight_decay': 0},
        ],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    if args.resume: