How to use the torchvision.transforms.CenterCrop function in torchvision

To help you get started, we’ve selected a few torchvision.transforms.CenterCrop examples, based on popular ways the function is used in public projects.
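
CenterCrop(size) crops the given image at its center to the requested output size. It is almost always paired with Resize, which scales the shorter side first so that a fixed square can then be cut from the middle. Below is a minimal sketch of that pattern; the file name 'example.jpg' is a placeholder:

from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),        # scale the shorter side to 256 px, keeping aspect ratio
    transforms.CenterCrop(224),    # cut a 224x224 square from the center
    transforms.ToTensor(),         # PIL image -> float tensor in [0, 1], shape (C, H, W)
])

img = Image.open('example.jpg')    # placeholder input image
tensor = preprocess(img)           # tensor.shape == torch.Size([3, 224, 224]) for RGB input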

github Lornatang / PyTorch / official / gan / wgan_gp / gan.py
if opt.dataset in ['imagenet', 'folder', 'lfw']:
  # folder dataset
  dataset = dset.ImageFolder(root=opt.dataroot,
                             transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.CenterCrop(opt.imageSize),
                               transforms.ToTensor(),
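                               # a mean/std of 0.5 rescales [0, 1] inputs to [-1, 1],
                               # the output range of a tanh-based GAN generator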
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ]))
  nc = 3
elif opt.dataset == 'lsun':
  dataset = dset.LSUN(root=opt.dataroot, classes=['bedroom_train'],
                      transform=transforms.Compose([
                        transforms.Resize(opt.imageSize),
                        transforms.CenterCrop(opt.imageSize),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                      ]))
  nc = 3
elif opt.dataset == 'cifar10':
  dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                           transforms.Resize(opt.imageSize),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                         ]))
  nc = 3
elif opt.dataset == 'cifar100':
  dataset = dset.CIFAR100(root=opt.dataroot, download=True,
                          transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ]))
  nc = 3
github ricky40403 / DSQ / train.py
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir,
            transforms.Compose([
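            # standard ImageNet evaluation preprocessing:
            # shorter side to 256 px, then a fixed 224x224 center crop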
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args, 0)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)
github xternalz / SDPoint / main.py
	if args.distributed:
		train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
	else:
		train_sampler = None

	if args.evaluate:
		args.batch_size = args.val_batch_size

	train_loader = torch.utils.data.DataLoader(
		train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
		num_workers=args.workers, pin_memory=True, sampler=train_sampler)

	val_loader = torch.utils.data.DataLoader(
		datasets.ImageFolder(valdir, transforms.Compose([
			transforms.Resize(256),
			transforms.CenterCrop(224),
			transforms.ToTensor(),
			normalize,
		])),
		batch_size=args.batch_size, shuffle=False,
		num_workers=args.workers, pin_memory=True)

	if args.evaluate:
		model.eval()
		val_results_file = open(args.val_results_path, 'w')
		val_results_file.write('blockID\tratio\tflops\ttop1-acc\ttop5-acc\t\n')
		for i in [-1] + [model.module.blockID] + list(range(model.module.blockID)):
			for r in [0.5, 0.75]:
				model_flops = flops.calculate(model, i, r)
				top1, top5 = validate(train_loader, val_loader, model, criterion, i, r)
				val_results_file.write('{0}\t{1}\t{2}\t{top1:.3f}\t{top5:.3f}\n'.format(
										i if i>-1 else 'nil', r if i>-1 else 'nil',
										model_flops, top1=top1, top5=top5))
github zhaoyanglijoey / Poem-From-Image / train_sentiment.py
    def __init__(self, train_data, test_data, img_dir, batchsize, load_model, device):
        self.device = device
        self.train_data = train_data
        self.test_data = test_data
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ])

        self.test_transform = transforms.Compose([
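            # deterministic eval-time counterpart of the RandomResizedCrop used for training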
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])

        self.train_set = VisualSentimentDataset(self.train_data, img_dir,
                                               transform=self.train_transform)
        self.train_loader = DataLoader(self.train_set, batch_size=batchsize, shuffle=True, num_workers=4)

        self.test_set = VisualSentimentDataset(self.test_data, img_dir,
                                              transform=self.test_transform)
        self.test_loader = DataLoader(self.test_set, batch_size=batchsize, num_workers=4)

        self.model = Res50_sentiment()
        self.model = DataParallel(self.model)
        if load_model:
            logger.info('load model from '+ load_model)
            self.model.load_state_dict(torch.load(load_model))
github hyunjaelee410 / style-based-recalibration-module / imagenet.py
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.train_batch, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create model
    model = resnet(
                depth=args.depth,
                recalibration_type=args.recalibration_type,
            )

    model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True
    print(model)
github rwightman / pytorch-nips2017-adversarial / python / validate_classifier.py
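    # resize so the img_size center crop below covers 87.5% of the resized shorter side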
    scale_size = int(math.floor(args.img_size / 0.875))
    if 'inception' in args.model:
        normalize = LeNormalize()
        scale_size = args.img_size
    elif 'dpn' in args.model:
        if args.img_size != 224:
            scale_size = args.img_size
        normalize = transforms.Normalize(mean=[124/255, 117/255, 104/255], std=[1/(.0167*255)]*3)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    dataset = datasets.ImageFolder(
        args.data,
        transforms.Compose([
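            # transforms.Scale is the old, deprecated name for transforms.Resize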
            transforms.Scale(scale_size, Image.BICUBIC),
            transforms.CenterCrop(args.img_size),
            transforms.ToTensor(),
            normalize,
        ]))

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
github hahnyuan / nn_tools / Datasets / imagenet.py
        def trans_val_data(dir):
            tensor = transforms.Compose([
                transforms.Scale(256),
                transforms.CenterCrop(224),
                transforms.ToTensor()
            ])(dir)
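            # ToTensor returns a float tensor in [0, 1]; convert back to uint8 pixel values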
            tensor = (tensor.numpy() * 255).astype(np.uint8)
            return tensor
github potterhsu / SVHNClassifier-PyTorch / infer.py
def _infer(path_to_checkpoint_file, path_to_input_image):
    model = Model()
    model.restore(path_to_checkpoint_file)
    model.cuda()

    with torch.no_grad():
        transform = transforms.Compose([
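            # Resize and CenterCrop also accept (h, w) sequences:
            # resize to exactly 64x64, then crop the central 54x54 region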
            transforms.Resize([64, 64]),
            transforms.CenterCrop([54, 54]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        image = Image.open(path_to_input_image)
        image = image.convert('RGB')
        image = transform(image)
        images = image.unsqueeze(dim=0).cuda()

        # model.eval() returns the module itself, so switching to eval mode and
        # running the forward pass can be chained in one expression
        length_logits, digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits = model.eval()(images)

        length_prediction = length_logits.max(1)[1]
        digit1_prediction = digit1_logits.max(1)[1]
        digit2_prediction = digit2_logits.max(1)[1]
        digit3_prediction = digit3_logits.max(1)[1]
        digit4_prediction = digit4_logits.max(1)[1]
        digit5_prediction = digit5_logits.max(1)[1]
github rwightman / pytorch-image-models / models / transforms.py
    if 'dpn' in model_name:
        scale_size = int(math.floor(img_size / crop_pct))
        normalize = transforms.Normalize(
            mean=IMAGENET_DPN_MEAN,
            std=IMAGENET_DPN_STD)
    elif 'inception' in model_name:
        scale_size = int(math.floor(img_size / crop_pct))
        normalize = LeNormalize()
    else:
        scale_size = int(math.floor(img_size / crop_pct))
        normalize = transforms.Normalize(
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD)

    return transforms.Compose([
        transforms.Resize(scale_size, Image.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize])
github diux-dev / cluster / pytorch / training / train_imagenet_nv_4gpu.py
def create_validation_set(valdir, batch_size, target_size, use_ar):
    idx_ar_sorted = sort_ar(valdir)
    idx_sorted, _ = zip(*idx_ar_sorted)
    idx2ar = map_idx2ar(idx_ar_sorted, batch_size)
    
    if use_ar:
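        # 1.14 ~= 1 / 0.875, so the crop keeps the usual 87.5% of the resized side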
        ar_tfms = [transforms.Resize(int(target_size*1.14)), CropArTfm(idx2ar, target_size)]
        val_dataset = ValDataset(valdir, transform=ar_tfms)
        val_sampler = ValDistSampler(idx_sorted, batch_size=batch_size)
        return val_dataset, val_sampler
    
    val_tfms = [transforms.Resize(int(args.sz*1.14)), transforms.CenterCrop(args.sz)]
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
    val_sampler = ValDistSampler(list(range(len(val_dataset))), batch_size=batch_size)
    return val_dataset, val_sampler