import math
import numpy as np
import torch
import torch.nn as nn
import fastcore.all as fc
from PIL import Image
from functools import partial
from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, Compose, ToTensor, ToPILImage
让我们把图像大小设为 224x224，补丁大小设为 32（即每张图像被划分为 7x7 = 49 个补丁）
# Side length (pixels) passed to RandomResizedCrop as the output size.
img_size = 224
# Window size and stride used when unfolding the image into patches.
patch_size = 32
加载数据
# Collect the paths matching *.jpg under coco/val2017/ into a fastcore `L`.
imgs = fc.L(fc.Path("coco/val2017/").glob("*.jpg"))
# Notebook display of the collection; the trailing comment is the recorded output.
imgs #(#5000) [Path('coco/val2017/000000182611.jpg'),Path('coco/val2017/000000335177.jpg'),Path('coco/val2017/000000278705.jpg'),Path('coco/val2017/000000463618.jpg'),Path('coco/val2017/000000568981.jpg'),Path('coco/val2017/000000092416.jpg'),Path('coco/val2017/000000173830.jpg'),Path('coco/val2017/000000476215.jpg'),Path('coco/val2017/000000479126.jpg'),Path('coco/val2017/000000570664.jpg')...]
转换
def transforms(img_size):
    """Build the augmentation pipeline applied to each loaded image.

    Args:
        img_size: Output size handed to ``RandomResizedCrop``.

    Returns:
        A ``Compose`` that random-resize-crops, randomly flips horizontally
        (p=0.5) and finally converts the image with ``ToTensor``.
    """
    random_crop = RandomResizedCrop(
        size=img_size,
        scale=[0.4, 1],
        ratio=[0.75, 1.33],
        interpolation=2,
    )
    random_flip = RandomHorizontalFlip(p=0.5)
    pipeline = [random_crop, random_flip, ToTensor()]
    return Compose(pipeline)
def load_img(img_loc, transforms):
    """Open the image at ``img_loc`` and return ``transforms`` applied to it.

    Args:
        img_loc: Location passed straight to ``PIL.Image.open``.
        transforms: Callable invoked with the RGB-converted PIL image.

    Returns:
        Whatever ``transforms`` returns for the image.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle once the transform has consumed the pixels.
    with Image.open(img_loc) as img:
        # Normalise to three channels so non-RGB files (e.g. grayscale)
        # yield the same channel layout as RGB ones downstream.
        return transforms(img.convert("RGB"))
# Rebind load_img to a partial with the transform pipeline pre-filled, so that
# from here on it is called with just an image location.
load_img = partial(load_img, transforms=transforms(img_size=img_size))
# Load the element at index 1 of the globbed path list as the working sample.
img = load_img(imgs[1])
# Notebook display; the trailing comment is the recorded output.
img.shape #torch.Size([3, 224, 224])
创建图像补丁
# Cut the image into non-overlapping patch_size x patch_size tiles.
# Unfolding dim 1 and then dim 2 appends two window axes; the
# permute / flatten / permute sequence merges the two patch-grid axes into one
# and moves that patch index to the front.
imgp = (
    img.unfold(1, patch_size, patch_size)
    .unfold(2, patch_size, patch_size)
    .permute((0, 3, 4, 1, 2))
    .flatten(3)
    .permute((3, 0, 1, 2))
)
imgp.shape #torch.Size([49, 3, 32, 32])

# Patches per side, derived from the constants instead of a hard-coded 7 so
# the plot keeps working when img_size or patch_size is changed.
grid = img_size // patch_size
# NOTE(review): `plt` is not imported in the visible part of this file;
# presumably `import matplotlib.pyplot as plt` is required - confirm.
# squeeze=False keeps `ax` a 2-D array even for a 1x1 grid, so `.flat` is valid.
fig, ax = plt.subplots(figsize=(4, 4), nrows=grid, ncols=grid, squeeze=False)
to_pil = ToPILImage()  # loop-invariant converter, built once
for n, i in enumerate(imgp):
    ax.flat[n].imshow(to_pil(i))
    ax.flat[n].axis("off")
plt.show()
创建屏蔽标记
# Build a per-patch indicator: 1 marks a masked patch, 0 a visible one.
mask_ratio = 0.75
tokens = len(imgp)
mask_count = int(mask_ratio * tokens)
tokens, mask_count #(49, 36)

# Take the first mask_count entries of a random permutation as masked indices.
shuffled = torch.randperm(tokens)
mask_idx = shuffled[:mask_count]
mask = torch.zeros(tokens, dtype=torch.long)
mask[mask_idx] = 1
mask
#tensor([1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1,
# 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1,
# 0])
# Draw the patch grid again, replacing every masked patch with zeros.
# Patches per side, derived from the constants instead of a hard-coded 7 so
# the plot keeps working when img_size or patch_size is changed.
grid = img_size // patch_size
# NOTE(review): `plt` is not imported in the visible part of this file;
# presumably `import matplotlib.pyplot as plt` is required - confirm.
# squeeze=False keeps `ax` a 2-D array even for a 1x1 grid, so `.flat` is valid.
fig, ax = plt.subplots(figsize=(4, 4), nrows=grid, ncols=grid, squeeze=False)
to_pil = ToPILImage()  # loop-invariant converter, built once
for n, i in enumerate(imgp):
    if mask[n] == 1:
        # Blank patch with the same shape/dtype as the real one, rather than
        # a hard-coded (3, 32, 32) tensor tied to one patch size.
        i = torch.zeros_like(i)
    ax.flat[n].imshow(to_pil(i))
    ax.flat[n].axis("off")
plt.show()
为每个非屏蔽标记创建嵌入向量。
# Select the patches whose mask entry is 0 and flatten each one into a vector.
visible = ~mask.bool()
visible_patches = imgp[visible, ...]
input_tokens = visible_patches.flatten(1)
input_tokens.shape
imgp[visible, ...].shape
https://embed.notionlytics.com/wt/ZXlKM2IzSnJjM0JoWTJWVWNtRmphMlZ5U1dRaU9pSlhiRWhvWlV4VVQxbHNjMlZYV2tKbU9URndaU0lzSW5CaFoyVkpaQ0k2SWpFd09ERmhaVGRpT1dFek1qZ3dNekZoT0RGbVptSTVORGsxTkdWbE16QTFJbjA9