Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Feature extraction: run VideoModel over the clips of args.video on the CPU.
device = torch.device("cpu")

model = VideoModel(pool_spatial=args.pool_spatial,
                   pool_temporal=args.pool_temporal)
model.eval()

# Inference only: freeze every parameter so autograd tracks nothing.
for params in model.parameters():
    params.requires_grad = False

model = model.to(device)
# NOTE(review): nn.DataParallel provides no parallelism when device is the CPU;
# kept as-is to preserve the model wrapper type.
model = nn.DataParallel(model)

# Per-channel normalization statistics (three channels).
mean, std = [0.43216, 0.394666, 0.37645], [0.22803, 0.22145, 0.216989]

transform = Compose([
    ToTensor(),
    # Frames arrive channel-last; move channels to the front: (c, t, h, w).
    Rearrange("t h w c -> c t h w"),
    Resize(args.frame_size),
    Normalize(mean=mean, std=std),
])

# dataset = WebcamDataset(clip=32, transform=transform)
dataset = VideoDataset(args.video, clip=32, transform=transform)
loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0, shuffle=False)

# Ceiling division: DataLoader's default drop_last=False also yields the final
# partial batch, which floor division would leave out of the progress total.
num_batches = (len(dataset) + args.batch_size - 1) // args.batch_size

# NOTE(review): nothing in the visible code appends to `features`; presumably
# `outputs` are collected into it after this point -- confirm.
features = []

with torch.no_grad():
    for inputs in tqdm(loader, total=num_batches):
        inputs = inputs.to(device)
        outputs = model(inputs)
# Visualization setup: build VideoModel on the CPU and load one clip from args.video.
device = torch.device("cpu")

model = VideoModel()
model.eval()

# Inference only: freeze every parameter so autograd tracks nothing.
for params in model.parameters():
    params.requires_grad = False

model = model.to(device)
# NOTE(review): nn.DataParallel provides no parallelism when device is the CPU;
# kept as-is to preserve the model wrapper type.
model = nn.DataParallel(model)

# Per-channel normalization statistics (three channels).
mean, std = [0.43216, 0.394666, 0.37645], [0.22803, 0.22145, 0.216989]

transform = Compose([
    ToTensor(),
    # Frames arrive channel-last; move channels to the front: (c, t, h, w).
    Rearrange("t h w c -> c t h w"),
    Resize(args.frame_size),
    Normalize(mean=mean, std=std),
])

# Frames per clip; used both to build the dataset and to validate the clip.
clip_length = 32

# Take first clip from video only for now.
# Could be made to run on the full video.
dataset = VideoDataset(args.video, clip=clip_length, transform=transform)

try:
    video = next(iter(dataset))
except StopIteration:
    # A bare StopIteration is confusing at top level (and becomes RuntimeError
    # inside a generator); report the real problem instead.
    raise ValueError(
        f"no clip of {clip_length} frames could be read from {args.video!r}"
    ) from None

# video = torch.rand(3, 32, 128, 128)

# Leading dims are (channels, frames) after the "t h w c -> c t h w" rearrange.
# Explicit check rather than `assert`, which is stripped under `python -O`.
if tuple(video.size()[0:2]) != (3, clip_length):
    raise ValueError(
        f"expected clip of shape (3, {clip_length}, h, w), got {tuple(video.size())}"
    )

# Add a leading batch axis of size 1: (c, t, h, w) -> (1, c, t, h, w).
video = rearrange(video, "c t h w -> () c t h w")
# detach() replaces the legacy `.data` access; both drop autograd history.
video = video.detach().cpu().numpy()
std (tuple): Normalization std-dev
"""
batch_size = len(batch)
plt.tight_layout()
fig, axs = plt.subplots(
batch_size,
sample_length,
figsize=(4 * sample_length, 3 * batch_size)
)
for i, ax in enumerate(axs):
if batch_size == 1:
clip = batch[0]
else:
clip = batch[i]
clip = Rearrange("c t h w -> t c h w")(clip)
if not isinstance(ax, np.ndarray):
ax = [ax]
for j, a in enumerate(ax):
a.axis("off")
a.imshow(
np.moveaxis(
denormalize(
clip[j],
mean,
std,
).numpy(),
0,
-1,
)