Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Evaluation setup: parse CLI args, restrict the visible GPU, build the
# VOC2011 'seg11valid' loader, pick the FCN variant from the model file's
# basename, and deserialize the weights file.
args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
model_file = args.model_file
root = osp.expanduser('~/data/datasets')
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False,
    # Pinned host memory only pays off when batches are copied to a GPU.
    num_workers=4, pin_memory=torch.cuda.is_available())
# Size the classifier from the dataset rather than a hard-coded constant.
n_class = len(val_loader.dataset.class_names)
# The architecture is inferred from the file-name prefix. 'fcn8s-atonce' must
# be tested before plain 'fcn8s', because it also matches that prefix.
model_name = osp.basename(model_file)
if model_name.startswith('fcn32s'):
    model = torchfcn.models.FCN32s(n_class=n_class)
elif model_name.startswith('fcn16s'):
    model = torchfcn.models.FCN16s(n_class=n_class)
elif model_name.startswith('fcn8s-atonce'):
    model = torchfcn.models.FCN8sAtOnce(n_class=n_class)
elif model_name.startswith('fcn8s'):
    model = torchfcn.models.FCN8s(n_class=n_class)
else:
    raise ValueError('Unsupported model file: %s' % model_file)
if torch.cuda.is_available():
    model = model.cuda()
print('==> Loading %s model file: %s' %
      (model.__class__.__name__, model_file))
# Deserialize onto the CPU. load_state_dict copies tensors into the model's
# own parameters on whatever device they live on. Without map_location, a
# file saved from a GPU run cannot be loaded on a CPU-only host.
model_data = torch.load(model_file, map_location='cpu')
# Copy the deserialized weights (model_data) into the model.
# NOTE(review): the `except`/`finally` clause that must close this `try` is
# not visible in SOURCE -- the snippet is cut off here; confirm against the
# full script.
try:
model.load_state_dict(model_data)
# NOTE(review): fragment -- the opening of the enclosing call (presumably a
# torch.utils.data.DataLoader(...) assignment) is not visible in SOURCE.
# These lines are its tail: a non-shuffled, batch-size-1 loader over the
# VOC2011 'seg11valid' split, with extra loader options taken from `kwargs`.
torchfcn.datasets.VOC2011ClassSeg(
root, split='seg11valid', transform=True),
batch_size=1, shuffle=False, **kwargs)
# 2. model
# Build a 21-class FCN32s. Either resume from a training checkpoint or
# initialize from pretrained VGG16-FCN32s weights (copy_fc8=False).
model = torchfcn.models.FCN32s(n_class=21)
start_epoch = 0
start_iteration = 0
if resume:
    # The model is still on the CPU here and is moved to the GPU below only
    # when `cuda` is set. Loading onto the CPU keeps GPU-saved checkpoints
    # usable on CPU-only hosts.
    checkpoint = torch.load(resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    vgg16_fcn32s = torchfcn.models.FCN32s(n_class=21)
    pretrained_path = osp.expanduser(
        '~/data/models/torch/vgg16-fcn32s.pth')
    vgg16_fcn32s.load_state_dict(
        torch.load(pretrained_path, map_location='cpu'))
    model.copy_params_from_vgg16(vgg16_fcn32s, copy_fc8=False)
if cuda:
    model = model.cuda()
# 3. optimizer
# Two parameter groups, split by get_parameters(..., bias=...). The bias
# group trains at twice the base learning rate with no weight decay; the
# other group uses the optimizer-level defaults from cfg.
optim = torch.optim.SGD(
    [
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True),
         'lr': cfg['lr'] * 2, 'weight_decay': 0},
    ],
    lr=cfg['lr'],
    momentum=cfg['momentum'],
    weight_decay=cfg['weight_decay'])
# 1. dataset
# Train on SBD 'train' (shuffled) and validate on VOC2011 'seg11valid'.
# Worker processes and pinned memory are enabled only when `cuda` is set.
root = osp.expanduser('~/data/datasets')
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)
# 2. model
# 21-class FCN32s: resume from a checkpoint or initialize from pretrained
# VGG16-FCN32s weights (copy_fc8=False).
model = torchfcn.models.FCN32s(n_class=21)
start_epoch = 0
start_iteration = 0
if resume:
    # Load onto the CPU: the model is moved to the GPU below only when `cuda`
    # is set, so GPU-saved checkpoints stay loadable on CPU-only hosts.
    checkpoint = torch.load(resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    vgg16_fcn32s = torchfcn.models.FCN32s(n_class=21)
    pretrained_path = osp.expanduser(
        '~/data/models/torch/vgg16-fcn32s.pth')
    vgg16_fcn32s.load_state_dict(
        torch.load(pretrained_path, map_location='cpu'))
    model.copy_params_from_vgg16(vgg16_fcn32s, copy_fc8=False)
if cuda:
    model = model.cuda()
# 3. optimizer
# NOTE(review): fragment -- the opening of the enclosing call (presumably a
# torch.utils.data.DataLoader(...) for the validation set) is not visible in
# SOURCE. It directly follows an "optimizer" section comment, which suggests
# two script excerpts were concatenated here; confirm against the real file.
torchfcn.datasets.VOC2011ClassSeg(
root, split='seg11valid', transform=True),
batch_size=1, shuffle=False, **kwargs)
# 2. model
# 21-class FCN16s: resume from a training checkpoint (args.resume), or
# initialize from a pretrained FCN32s whose weights are in
# args.pretrained_model.
model = torchfcn.models.FCN16s(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # Load onto the CPU: the model is moved to the GPU below only when `cuda`
    # is set, so GPU-saved checkpoints stay loadable on CPU-only hosts.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    fcn32s = torchfcn.models.FCN32s()
    state_dict = torch.load(args.pretrained_model, map_location='cpu')
    try:
        fcn32s.load_state_dict(state_dict)
    except RuntimeError:
        # Not a bare state dict: fall back to a checkpoint dict that stores
        # the weights under 'model_state_dict'.
        fcn32s.load_state_dict(state_dict['model_state_dict'])
    model.copy_params_from_fcn32s(fcn32s)
if cuda:
    model = model.cuda()
# 3. optimizer
# SGD with two parameter groups, split via get_parameters(model, bias=...):
# the first group uses the optimizer-level defaults; the second trains at
# twice args.lr with weight decay disabled.
# NOTE(review): fragment -- the rest of this call (closing the group list and
# the optimizer-level keyword arguments) is not visible in SOURCE.
optim = torch.optim.SGD(
[
{'params': get_parameters(model, bias=False)},
{'params': get_parameters(model, bias=True),
'lr': args.lr * 2, 'weight_decay': 0},
# NOTE(review): fragment -- tail of an if/elif chain that picks the dataset
# class; the branch conditions are not visible in SOURCE. The else branch
# rejects anything unrecognised, reporting cfg['dataset'].
dataset_class = torchfcn.datasets.APC2016V3
else:
raise ValueError('Unsupported dataset: %s' % cfg['dataset'])
# Build train/valid loaders from the selected dataset_class. Worker
# processes and pinned memory are enabled only when `cuda` is set.
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    dataset_class(split='train', transform=True),
    batch_size=batch_size, shuffle=True, **kwargs)
valid_loader = torch.utils.data.DataLoader(
    dataset_class(split='valid', transform=True),
    batch_size=batch_size, shuffle=False, **kwargs)
# 2. model
# The class count comes from the training set's class_names.
n_class = len(train_loader.dataset.class_names)
model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=cfg['nodeconv'])
start_epoch = 0
if resume:
    # Load onto the CPU: the model is moved to the GPU(s) below only when
    # `cuda` is set, so GPU-saved checkpoints stay loadable on CPU-only hosts.
    checkpoint = torch.load(resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
else:
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16, copy_fc8=False, init_upscore=False)
if cuda:
    # Wrap in DataParallel only when more than one device is visible.
    if torch.cuda.device_count() == 1:
        model = model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()
# 3. optimizer
# 1. dataset
# Train on SBD 'train' (shuffled) and validate on VOC2011 'seg11valid'.
# Worker processes and pinned memory are enabled only when `cuda` is set.
root = osp.expanduser('~/data/datasets')
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)
# 2. model
# 21-class FCN32s: resume from a checkpoint (args.resume), or initialize
# from torchfcn's pretrained VGG16.
model = torchfcn.models.FCN32s(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    # Load onto the CPU: the model is moved to the GPU below only when `cuda`
    # is set, so GPU-saved checkpoints stay loadable on CPU-only hosts.
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16)
if cuda:
    model = model.cuda()
# 3. optimizer
# NOTE(review): fragment -- the argument list of this SGD constructor call
# continues past the end of the visible SOURCE.
optim = torch.optim.SGD(