How to use the torchfcn.datasets module in torchfcn

To help you get started, we’ve selected a few torchfcn examples based on popular ways the module is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

Source: GitHub, wkentaro/pytorch-fcn — examples/apc/train_fcn32s.py (view on GitHub)
# Fragment of the training entry point in examples/apc/train_fcn32s.py: sets up
# seeds, the APC2016 train/valid DataLoaders and an FCN32s model. The enclosing
# function header and the code after this fragment are not shown here.
# NOTE(review): as pasted, the first statement is at column 0 while the rest is
# indented (copy/paste artifact); restore the original indentation to run it.
cuda = torch.cuda.is_available()

    # NOTE(review): batch_size is 3 per visible CUDA device. If no GPU is
    # visible, device_count() presumably returns 0, so batch_size is 0 and the
    # floor division below raises ZeroDivisionError — confirm and guard if CPU
    # runs are meant to be supported.
    batch_size = torch.cuda.device_count() * 3
    # Floor-divide the configured iteration budget by the batch size.
    max_iter = cfg['max_iteration'] // batch_size

    # Fixed seed for the CPU RNG and, when CUDA is available, the CUDA RNG
    # (presumably for repeatable runs).
    torch.manual_seed(1)
    if cuda:
        torch.cuda.manual_seed(1)

    # 1. dataset
    # Default cfg['dataset'] to 'v2' (written back into cfg) and map it to the
    # matching torchfcn dataset class; any other value raises ValueError.

    cfg['dataset'] = cfg.get('dataset', 'v2')
    if cfg['dataset'] == 'v2':
        dataset_class = torchfcn.datasets.APC2016V2
    elif cfg['dataset'] == 'v3':
        dataset_class = torchfcn.datasets.APC2016V3
    else:
        raise ValueError('Unsupported dataset: %s' % cfg['dataset'])

    # Worker processes and pinned memory are only requested when CUDA is
    # available; otherwise DataLoader defaults are used.
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    # transform=True presumably enables the dataset's input preprocessing —
    # confirm in torchfcn.datasets. Training data is shuffled; validation
    # data keeps a fixed order.
    train_loader = torch.utils.data.DataLoader(
        dataset_class(split='train', transform=True),
        batch_size=batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        dataset_class(split='valid', transform=True),
        batch_size=batch_size, shuffle=False, **kwargs)

    # 2. model
    # The number of output classes comes from the training dataset's
    # class_names; nodeconv is read from cfg (its meaning is defined by
    # torchfcn.models.FCN32s, not shown here).

    n_class = len(train_loader.dataset.class_names)
    model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=cfg['nodeconv'])
    start_epoch = 0
Source: GitHub, wkentaro/pytorch-fcn — examples/apc/train_fcn32s.py (view on GitHub)
# Fragment of the training entry point in examples/apc/train_fcn32s.py: loads
# the config, then sets up seeds and the APC2016 train/valid DataLoaders. The
# enclosing function header and the code after this fragment are not shown.
# NOTE(review): as pasted, the first statement is at column 0 while the rest is
# indented (copy/paste artifact); restore the original indentation to run it.
# load_config_file is defined elsewhere; it presumably returns the config
# dict and an output location — confirm against its definition.
cfg, out = load_config_file(config_file)

    cuda = torch.cuda.is_available()

    # NOTE(review): batch_size is 3 per visible CUDA device. If no GPU is
    # visible, device_count() presumably returns 0, so batch_size is 0 and the
    # floor division below raises ZeroDivisionError — confirm and guard if CPU
    # runs are meant to be supported.
    batch_size = torch.cuda.device_count() * 3
    # Floor-divide the configured iteration budget by the batch size.
    max_iter = cfg['max_iteration'] // batch_size

    # Fixed seed for the CPU RNG and, when CUDA is available, the CUDA RNG
    # (presumably for repeatable runs).
    torch.manual_seed(1)
    if cuda:
        torch.cuda.manual_seed(1)

    # 1. dataset
    # Default cfg['dataset'] to 'v2' (written back into cfg) and map it to the
    # matching torchfcn dataset class; any other value raises ValueError.

    cfg['dataset'] = cfg.get('dataset', 'v2')
    if cfg['dataset'] == 'v2':
        dataset_class = torchfcn.datasets.APC2016V2
    elif cfg['dataset'] == 'v3':
        dataset_class = torchfcn.datasets.APC2016V3
    else:
        raise ValueError('Unsupported dataset: %s' % cfg['dataset'])

    # Worker processes and pinned memory are only requested when CUDA is
    # available; otherwise DataLoader defaults are used.
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    # transform=True presumably enables the dataset's input preprocessing —
    # confirm in torchfcn.datasets. Training data is shuffled; validation
    # data keeps a fixed order.
    train_loader = torch.utils.data.DataLoader(
        dataset_class(split='train', transform=True),
        batch_size=batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        dataset_class(split='valid', transform=True),
        batch_size=batch_size, shuffle=False, **kwargs)

    # 2. model
    # The number of output classes comes from the training dataset's
    # class_names.

    n_class = len(train_loader.dataset.class_names)
Source: GitHub, wkentaro/pytorch-fcn — examples/voc/train_fcn32s.py (view on GitHub)
# Fragment of the training entry point in examples/voc/train_fcn32s.py: selects
# the GPU, seeds the RNGs, builds the SBD/VOC DataLoaders and creates an FCN32s
# model, either resumed from a checkpoint or initialized from VGG16. The
# enclosing function header and the code after this fragment are not shown.
# NOTE(review): as pasted, the first statement is at column 0 while the rest is
# indented (copy/paste artifact); restore the original indentation to run it.
# Restrict the visible GPUs to args.gpu. NOTE(review): presumably this must run
# before CUDA is first initialized to take effect — confirm call order.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    # Fixed seed for the CPU RNG and, when CUDA is available, the CUDA RNG
    # (presumably for repeatable runs).
    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    # Train on SBDClassSeg split 'train' and validate on VOC2011ClassSeg split
    # 'seg11valid', both read from ~/data/datasets with batch_size=1.

    root = osp.expanduser('~/data/datasets')
    # Worker processes and pinned memory are only requested when CUDA is
    # available; otherwise DataLoader defaults are used.
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
        batch_size=1, shuffle=True, **kwargs)
    val_loader = torch.utils.data.DataLoader(
        torchfcn.datasets.VOC2011ClassSeg(
            root, split='seg11valid', transform=True),
        batch_size=1, shuffle=False, **kwargs)

    # 2. model
    # n_class is hard-coded to 21 (presumably the 20 PASCAL VOC classes plus
    # background — confirm against the dataset's class_names).

    model = torchfcn.models.FCN32s(n_class=21)
    start_epoch = 0
    start_iteration = 0
    if args.resume:
        # Resume: restore the model weights and the epoch/iteration counters
        # from the checkpoint. NOTE(review): torch.load unpickles the file,
        # so only load checkpoints from trusted sources.
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        # Fresh start: copy parameters from a VGG16 built with pretrained=True.
        vgg16 = torchfcn.models.VGG16(pretrained=True)
        model.copy_params_from_vgg16(vgg16)