How to use the lpips.networks_basic.PNet class in lpips

To help you get started, we've selected a few lpips examples based on popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.

GitHub: revbucket / mister_ed / lpips / dist_model.py (View on GitHub, external link)
self.net = net
        self.use_gpu = use_gpu
        if self.use_gpu:
            self.map_location = None
        else:
           self.map_location = lambda storage, loc: storage
        self.model_name = '%s [%s]'%(model,net)
        if(self.model == 'net-lin'): # pretrained net + linear layer
            self.net = networks.PNetLin(use_gpu=use_gpu,pnet_type=net,use_dropout=True)

            weight_path =  os.path.join(os.path.dirname(__file__), 'weights', '%s.pth' % net)

            self.net.load_state_dict(torch.load(weight_path, 
                                     map_location=self.map_location))
        elif(self.model=='net'): # pretrained network
            self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net)
            self.is_fake_net = True
        elif(self.model in ['L2','l2']):
            self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())
        self.net.eval()

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
GitHub: revbucket / mister_ed / lpips / networks_basic.py (View on GitHub, external link)
def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True):
        super(PNet, self).__init__()

        self.use_gpu = use_gpu

        self.pnet_type = pnet_type
        self.pnet_rand = pnet_rand

        self.shift = torch.autograd.Variable(torch.Tensor([-.030, -.088, -.188]).view(1,3,1,1))
        self.scale = torch.autograd.Variable(torch.Tensor([.458, .448, .450]).view(1,3,1,1))

        if(self.pnet_type in ['vgg','vgg16']):
            self.net = pn.vgg16(pretrained=not self.pnet_rand,requires_grad=False)
        elif(self.pnet_type=='alex'):
            self.net = pn.alexnet(pretrained=not self.pnet_rand,requires_grad=False)
        elif(self.pnet_type[:-2]=='resnet'):
            self.net = pn.resnet(pretrained=not self.pnet_rand,requires_grad=False, num=int(self.pnet_type[-2:]))
        elif(self.pnet_type=='squeeze'):