The Vision Transformer (ViT) adapts the Transformer, a model that has been groundbreaking in NLP since its introduction, to the image domain.
The code and explanations below are based on the following post.
Vision Transformer (ViT) Implementation Code¶
code & description reference : https://yhkim4504.tistory.com/5
Patch Embedding¶
Split the image into patches and embed them (adding a class token and a positional embedding).
In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary
In [3]:
x = Image.open('./data/cat01.jpeg')
x = x.resize((224,224))
tf_toTensor = ToTensor()
x = tf_toTensor(x)
x = torch.unsqueeze(x, 0)  # add a batch dimension
x.shape
Out[3]:
torch.Size([1, 3, 224, 224])
In [4]:
x = torch.randn(8,3,224,224)
In [5]:
P = 16
N = int(224*224/(16*16)) #196
"""
기존의 B*C*H*W의 차원을 B*N*(P*P*C)로 바꿔줘야함
einops의 rearrange 함수를 이용하여 이미지를 패치로 나누고 flatten을 한번에 수행할 수 있다.
"""
patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=P, s2=P) #[1,N,P*P*c\]
In [6]:
patches.shape
Out[6]:
torch.Size([8, 196, 768])
However, ViT does not use the linear embedding above; instead it applies a 2D convolutional layer and then flattens the result, which yields a performance gain.
In [7]:
P = 16
in_channels = 3
emb_size = 768
projection = nn.Sequential(
    nn.Conv2d(in_channels, emb_size, kernel_size=P, stride=P),
    Rearrange('b e h w -> b (h w) e')
)
projection(x).shape
Out[7]:
torch.Size([8, 196, 768])
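As a side check (not in the referenced post), a strided Conv2d with kernel_size = stride = P is mathematically the same as flattening each patch and applying a single Linear layer; the sketch below copies the convolution weights into a hypothetical linear layer and compares the two outputs.
# Sketch: the patchify convolution equals a Linear layer on flattened patches.
conv = nn.Conv2d(3, emb_size, kernel_size=P, stride=P)
linear = nn.Linear(3 * P * P, emb_size)
with torch.no_grad():
    # Conv2d weight is [emb_size, C, P, P]; flatten it in (C, P, P) order
    linear.weight.copy_(conv.weight.reshape(emb_size, -1))
    linear.bias.copy_(conv.bias)
# flatten each patch in the same (C, P, P) order as the conv weight
flat_patches = rearrange(x, 'b c (h s1) (w s2) -> b (h w) (c s1 s2)', s1=P, s2=P)
out_linear = linear(flat_patches)
out_conv = rearrange(conv(x), 'b e h w -> b (h w) e')
print(torch.allclose(out_linear, out_conv, atol=1e-5))  # expected: True (up to float error)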
Adding the cls token and the positional encoding
In [16]:
emb_size = 768
img_size = 224
patch_size = 16
# split the image into patches and flatten
proj_x = projection(x)
print("Projected X shape : ", proj_x.shape)
# declare the cls_token and positional-embedding parameters
cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))
print("Cls shape : ", cls_token.shape)
print("Pos shape : ", positions.shape)
# expand the cls_token to the batch size
batch_size = 8
cls_tokens = repeat(cls_token, '() n e -> b n e', b=batch_size)
print("Repeated Cls shape : ", cls_tokens.shape)
# concatenate cls_token and proj_x
cat_x = torch.cat([cls_tokens, proj_x], dim=1)
print("Cls+proj_x : ", cat_x.shape)
# add the positional encoding
cat_x += positions
print("Output : ", cat_x.shape)
Projected X shape : torch.Size([8, 196, 768])
Cls shape : torch.Size([1, 1, 768])
Pos shape : torch.Size([197, 768])
Repeated Cls shape : torch.Size([8, 1, 768])
Cls+proj_x : torch.Size([8, 197, 768])
Output : torch.Size([8, 197, 768])
Wrapping the steps above into a class
In [24]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        # split the image into patches with a strided convolution, then flatten
        self.projection = nn.Sequential(
            nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e h w -> b (h w) e')
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.positions = nn.Parameter(torch.randn((img_size // patch_size)**2 + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.projection(x)
        # prepend a cls token for every sample in the batch
        cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=b)
        x = torch.cat([cls_tokens, x], dim=1)
        # add the learned positional embedding
        x += self.positions
        return x
In [25]:
embed = PatchEmbedding()
x = embed(x)
In [26]:
x.shape
Out[26]:
torch.Size([8, 197, 768])
Multi head attention¶
Linear projection
In [27]:
emb = 768
num_heads = 8
# linear projections for the query, key, and value inputs
keys = nn.Linear(emb, emb)
queries = nn.Linear(emb, emb)
values = nn.Linear(emb, emb)
Multi head
In [28]:
# split q, k, v into num_heads (8) heads after the linear projection
queries = rearrange(queries(x), 'b n (h d) -> b h n d', h=num_heads)
keys = rearrange(keys(x), 'b n (h d) -> b h n d', h=num_heads)
values = rearrange(values(x), 'b n (h d) -> b h n d', h=num_heads)
In [29]:
queries.shape, keys.shape, values.shape
Out[29]:
(torch.Size([8, 8, 197, 96]), torch.Size([8, 8, 197, 96]), torch.Size([8, 8, 197, 96]))
Scaled dot product attention
In [30]:
# queries * keys^T
# Q.matmul(K.T)
energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
print("energy : ", energy.shape)
# scale the energy, then softmax to get the attention scores
scaling = emb_size**(1/2)
att = F.softmax(energy / scaling, dim=-1)
print("att : ", att.shape)
# attention scores * values
out = torch.einsum('bhal, bhlv -> bhav', att, values)
print("out : ", out.shape)
# rearrange back to emb_size (concatenate the heads)
out = rearrange(out, "b h n d -> b n (h d)")
print("out2 : ", out.shape)
energy : torch.Size([8, 8, 197, 197])
att : torch.Size([8, 8, 197, 197])
out : torch.Size([8, 8, 197, 96])
out2 : torch.Size([8, 197, 768])
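As a sanity check (not part of the referenced post), PyTorch's built-in F.scaled_dot_product_attention should reproduce the manual computation above when given the same 1/sqrt(emb_size) scaling; this assumes a PyTorch version (2.1 or later) that exposes the scale argument.
# Sanity-check sketch (assumes PyTorch >= 2.1 for the `scale` argument):
# the fused kernel should match the manual attention above when no dropout is used.
out_ref = F.scaled_dot_product_attention(queries, keys, values, scale=1.0 / scaling)
out_ref = rearrange(out_ref, 'b h n d -> b n (h d)')
print(torch.allclose(out, out_ref, atol=1e-5))  # expected: True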
In [31]:
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size: int = 768, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # fuse the queries, keys, and values into one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        # split the keys, queries, and values into num_heads heads
        qkv = rearrange(self.qkv(x), "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        queries, keys, values = qkv[0], qkv[1], qkv[2]
        # dot product between queries and keys -> (batch, num_heads, query_len, key_len)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
        # scale before the softmax, then apply attention dropout
        scaling = self.emb_size**(1/2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # weight the values by the attention scores
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.projection(out)
        return out
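A quick shape check (not in the original notebook): the attention block should preserve the [batch, tokens, emb_size] shape.
# Quick shape check: MultiHeadAttention maps [B, N, emb_size] -> [B, N, emb_size].
mha = MultiHeadAttention()
print(mha(torch.randn(8, 197, 768)).shape)  # expected: torch.Size([8, 197, 768])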
Residual Block¶
In [32]:
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # skip connection: add the input back onto the output of fn
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
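A small check (a sketch, not from the source) of the residual behaviour, using a hypothetical Linear layer as fn:
# ResidualAdd(fn) computes fn(x) + x, i.e. a skip connection around fn.
layer = nn.Linear(4, 4)
res_block = ResidualAdd(layer)
t = torch.randn(2, 4)
print(torch.allclose(res_block(t), layer(t) + t))  # expected: True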
Feed Forward MLP¶
- Linear - GELU - Dropout - Linear
- The first layer expands the embedding to expansion * emb_size
- The second layer projects it back down to the original emb_size (see the shape check after the class below)
In [33]:
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size: int, expansion: int = 4, drop_p: float = 0.):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size))
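A quick check of the expand-then-project behaviour described above (a sketch, not part of the referenced post):
# The first Linear expands to 4 * emb_size; the block output returns to emb_size.
ff = FeedForwardBlock(emb_size=768, expansion=4)
print(ff[0].weight.shape)                   # expected: torch.Size([3072, 768])
print(ff(torch.randn(8, 197, 768)).shape)   # expected: torch.Size([8, 197, 768])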
Transformer Encoder Block¶
In [45]:
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size: int = 768,
                 drop_p: float = 0.,
                 forward_expansion: int = 4,
                 forward_drop_p: float = 0.,
                 **kwargs):
        super().__init__(
            # pre-norm multi-head attention sub-block with a residual connection
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, **kwargs),
                nn.Dropout(drop_p)
            )),
            # pre-norm feed-forward sub-block with a residual connection
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion,
                    drop_p=forward_drop_p),
            )))
In [46]:
x = torch.randn(8,3,224,224)
patches_embedded = PatchEmbedding()(x)
TransformerEncoderBlock()(patches_embedded).shape
Out[46]:
torch.Size([8, 197, 768])
Building Block¶
In [48]:
class TransformerEncoder(nn.Sequential):
    def __init__(self, depth: int = 12, **kwargs):
        super().__init__(*[TransformerEncoderBlock(**kwargs) for _ in range(depth)])
- The * in front of [TransformerEncoderBlock(**kwargs) for _ in range(depth)] unpacks the list, so the blocks are passed to nn.Sequential as separate arguments rather than as a single list (a minimal illustration follows below).
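A minimal illustration (hypothetical example, not from the source) of why the list has to be unpacked:
# nn.Sequential expects its modules as separate positional arguments.
blocks = [nn.Linear(4, 4) for _ in range(3)]
seq = nn.Sequential(*blocks)   # OK: three child modules
# nn.Sequential(blocks)        # would fail: a plain list is not an nn.Module
print(len(seq))                # expected: 3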
Head¶
In [52]:
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size: int = 768, n_classes: int = 10):
        super().__init__(
            # average over all tokens, then normalize and project to the class logits
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes))
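A quick shape check (a sketch, not part of the referenced post): the head reduces the 197 tokens to a single vector and maps it to n_classes logits.
head = ClassificationHead(emb_size=768, n_classes=10)
print(head(torch.randn(8, 197, 768)).shape)  # expected: torch.Size([8, 10])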
Summary¶
In [55]:
class ViT(nn.Sequential):
    def __init__(self,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 emb_size: int = 768,
                 img_size: int = 224,
                 depth: int = 12,
                 n_classes: int = 10,
                 **kwargs):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth, emb_size=emb_size, **kwargs),
            ClassificationHead(emb_size, n_classes))

summary(ViT(), (3, 224, 224), device='cpu')
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1          [-1, 768, 14, 14]         590,592
         Rearrange-2             [-1, 196, 768]               0
    PatchEmbedding-3             [-1, 197, 768]               0
         LayerNorm-4             [-1, 197, 768]           1,536
            Linear-5            [-1, 197, 2304]       1,771,776
           Dropout-6          [-1, 8, 197, 197]               0
            Linear-7             [-1, 197, 768]         590,592
MultiHeadAttention-8             [-1, 197, 768]               0
           Dropout-9             [-1, 197, 768]               0
      ResidualAdd-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
           Linear-12            [-1, 197, 3072]       2,362,368
             GELU-13            [-1, 197, 3072]               0
          Dropout-14            [-1, 197, 3072]               0
           Linear-15             [-1, 197, 768]       2,360,064
      ResidualAdd-16             [-1, 197, 768]               0
               ...   (the same encoder-block pattern repeats for
                      the remaining 11 blocks, layers 17-159)
          Reduce-160                  [-1, 768]               0
       LayerNorm-161                  [-1, 768]           1,536
          Linear-162                   [-1, 10]           7,690
================================================================
Total params: 85,654,282
Trainable params: 85,654,282
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 350.47
Params size (MB): 326.75
Estimated Total Size (MB): 677.79
----------------------------------------------------------------
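Finally, a minimal end-to-end check (a sketch, not part of the original post) that the assembled model produces class logits of the expected shape:
# One forward pass through the assembled ViT.
model = ViT(img_size=224, n_classes=10)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([2, 10])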