cfg, out = load_config_file(config_file)

cuda = torch.cuda.is_available()

batch_size = torch.cuda.device_count() * 3
max_iter = cfg['max_iteration'] // batch_size

torch.manual_seed(1)
if cuda:
    torch.cuda.manual_seed(1)

# 1. dataset

cfg['dataset'] = cfg.get('dataset', 'v2')
if cfg['dataset'] == 'v2':
    dataset_class = torchfcn.datasets.APC2016V2
elif cfg['dataset'] == 'v3':
    dataset_class = torchfcn.datasets.APC2016V3
else:
    raise ValueError('Unsupported dataset: %s' % cfg['dataset'])

kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    dataset_class(split='train', transform=True),
    batch_size=batch_size, shuffle=True, **kwargs)
valid_loader = torch.utils.data.DataLoader(
    dataset_class(split='valid', transform=True),
    batch_size=batch_size, shuffle=False, **kwargs)

# 2. model

n_class = len(train_loader.dataset.class_names)

model = torchfcn.models.FCN32s(n_class=n_class, nodeconv=cfg['nodeconv'])
start_epoch = 0
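# --- Added note (not in the original script) --------------------------------
# batch_size above is derived from torch.cuda.device_count() * 3, which is 0
# on a CPU-only machine and would break both the DataLoader and the max_iter
# division. A minimal guard, assuming the same cfg dict as above:

import torch

cfg = {'max_iteration': 100000}  # placeholder config, for illustration only

batch_size = max(1, torch.cuda.device_count() * 3)  # 3 images per GPU, or 1 on CPU
max_iter = cfg['max_iteration'] // batch_size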
import numpy as np
import skimage.data
import torch

from torchfcn.models.fcn32s import get_upsampling_weight  # import path assumed


def test_get_upsampling_weight():
    src = skimage.data.coffee()
    x = src.transpose(2, 0, 1)
    x = x[np.newaxis, :, :, :]
    x = torch.from_numpy(x).float()
    x = torch.autograd.Variable(x)

    in_channels = 3
    out_channels = 3
    kernel_size = 4

    m = torch.nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size, stride=2, bias=False)
    m.weight.data = get_upsampling_weight(
        in_channels, out_channels, kernel_size)

    y = m(x)

    y = y.data.numpy()
    y = y[0]
    y = y.transpose(1, 2, 0)
    dst = y.astype(np.uint8)

    # a stride-2 ConvTranspose2d should roughly double the spatial size
    assert abs(src.shape[0] * 2 - dst.shape[0]) <= 2
    assert abs(src.shape[1] * 2 - dst.shape[1]) <= 2

    return src, dst
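# --- Added sketch (not in the original test) --------------------------------
# The test assumes get_upsampling_weight() returns a bilinear-interpolation
# kernel for ConvTranspose2d, so that the transposed convolution roughly
# doubles the image size. A self-contained sketch of how such a kernel is
# typically built (function name and exact normalization are ours):

import numpy as np
import torch


def bilinear_upsampling_weight(in_channels, out_channels, kernel_size):
    """Bilinear ConvTranspose2d weight of shape
    (in_channels, out_channels, kernel_size, kernel_size)."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = ((1 - abs(og[0] - center) / factor) *
            (1 - abs(og[1] - center) / factor))
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()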
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

model_file = args.model_file

root = osp.expanduser('~/data/datasets')
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False,
    num_workers=4, pin_memory=True)

n_class = len(val_loader.dataset.class_names)

# pick the architecture from the checkpoint file name
if osp.basename(model_file).startswith('fcn32s'):
    model = torchfcn.models.FCN32s(n_class=21)
elif osp.basename(model_file).startswith('fcn16s'):
    model = torchfcn.models.FCN16s(n_class=21)
elif osp.basename(model_file).startswith('fcn8s'):
    if osp.basename(model_file).startswith('fcn8s-atonce'):
        model = torchfcn.models.FCN8sAtOnce(n_class=21)
    else:
        model = torchfcn.models.FCN8s(n_class=21)
else:
    raise ValueError('Unrecognized model file: %s' % model_file)
if torch.cuda.is_available():
    model = model.cuda()

print('==> Loading %s model file: %s' %
      (model.__class__.__name__, model_file))
model_data = torch.load(model_file)
try:
    # either a plain state dict or a full training checkpoint
    model.load_state_dict(model_data)
except Exception:
    model.load_state_dict(model_data['model_state_dict'])
model.eval()

print('==> Evaluating with VOC2011ClassSeg seg11valid')
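# --- Added sketch (not in the original script) ------------------------------
# A minimal evaluation loop that could follow the print above, assuming the
# val_loader / model / n_class defined earlier and that
# torchfcn.utils.label_accuracy_score(label_trues, label_preds, n_class)
# returns (acc, acc_cls, mean_iu, fwavacc):

import torch
import torchfcn

label_trues, label_preds = [], []
with torch.no_grad():
    for img, lbl in val_loader:
        if torch.cuda.is_available():
            img = img.cuda()
        score = model(img)                              # (1, n_class, H, W)
        label_preds.append(score.max(1)[1].cpu().numpy()[0])
        label_trues.append(lbl.numpy()[0])

acc, acc_cls, mean_iu, fwavacc = torchfcn.utils.label_accuracy_score(
    label_trues, label_preds, n_class)
print('Accuracy: {:.2f}, AccClass: {:.2f}, MeanIU: {:.2f}, FWAVAcc: {:.2f}'.format(
    acc * 100, acc_cls * 100, mean_iu * 100, fwavacc * 100))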
def get_parameters(model, bias=False):
    import torch.nn as nn
    modules_skipped = (
        nn.ReLU,
        nn.MaxPool2d,
        nn.Dropout2d,
        nn.Sequential,
        torchfcn.models.FCN32s,
        torchfcn.models.FCN16s,
        torchfcn.models.FCN8s,
    )
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            if bias:
                yield m.bias
            else:
                yield m.weight
        elif isinstance(m, nn.ConvTranspose2d):
            # weight is frozen because it is just a bilinear upsampling
            if bias:
                assert m.bias is None
        elif isinstance(m, modules_skipped):
            continue
        else:
            raise ValueError('Unexpected module: %s' % str(m))
# 1. dataset

root = osp.expanduser('~/data/datasets')
kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.SBDClassSeg(root, split='train', transform=True),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model

model = torchfcn.models.FCN32s(n_class=21)
start_epoch = 0
start_iteration = 0
if resume:
    checkpoint = torch.load(resume)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    vgg16_fcn32s = torchfcn.models.FCN32s(n_class=21)
    vgg16_fcn32s.load_state_dict(
        torch.load(osp.expanduser('~/data/models/torch/vgg16-fcn32s.pth')))
    model.copy_params_from_vgg16(vgg16_fcn32s, copy_fc8=False)
if cuda:
    model = model.cuda()

# 3. optimizer

optim = torch.optim.SGD(
    [
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True),
         'lr': cfg['lr'] * 2, 'weight_decay': 0},
    ],
    lr=cfg['lr'],
    momentum=cfg['momentum'],
    weight_decay=cfg['weight_decay'])
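# --- Added sketch (not in the original script) ------------------------------
# The original scripts hand the pieces above to the repo's Trainer class; the
# loop below is only a minimal illustration of how model, optim and
# train_loader fit together. Assumed names: model, optim, train_loader, cuda,
# start_epoch from the snippet above; max_epoch is a placeholder.

import torch.nn.functional as F

max_epoch = 12  # placeholder

model.train()
for epoch in range(start_epoch, max_epoch):
    for img, lbl in train_loader:
        if cuda:
            img, lbl = img.cuda(), lbl.cuda()
        optim.zero_grad()
        score = model(img)                      # (N, 21, H, W) class scores
        # -1 is assumed to mark unlabeled pixels in these datasets
        loss = F.cross_entropy(score, lbl, ignore_index=-1)
        loss.backward()
        optim.step()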
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model

model = torchfcn.models.FCN16s(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # initialize FCN16s from a trained FCN32s checkpoint
    fcn32s = torchfcn.models.FCN32s()
    state_dict = torch.load(args.pretrained_model)
    try:
        fcn32s.load_state_dict(state_dict)
    except RuntimeError:
        fcn32s.load_state_dict(state_dict['model_state_dict'])
    model.copy_params_from_fcn32s(fcn32s)
if cuda:
    model = model.cuda()

# 3. optimizer

optim = torch.optim.SGD(
    [
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True),
         'lr': args.lr * 2, 'weight_decay': 0},
    ],
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay)
val_loader = torch.utils.data.DataLoader(
    torchfcn.datasets.VOC2011ClassSeg(
        root, split='seg11valid', transform=True),
    batch_size=1, shuffle=False, **kwargs)

# 2. model

model = torchfcn.models.FCN8sAtOnce(n_class=21)
start_epoch = 0
start_iteration = 0
if args.resume:
    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    start_iteration = checkpoint['iteration']
else:
    # initialize directly from ImageNet-pretrained VGG16
    vgg16 = torchfcn.models.VGG16(pretrained=True)
    model.copy_params_from_vgg16(vgg16)
if cuda:
    model = model.cuda()

# 3. optimizer

optim = torch.optim.SGD(
    [
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True),
         'lr': args.lr * 2, 'weight_decay': 0},
    ],
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay)
if args.resume:
    # restore the optimizer state as well (key name assumed here; the original
    # script continues past this point)
    optim.load_state_dict(checkpoint['optim_state_dict'])