MAE 源代码理解part1:调试理解法
git官⽅链接:
下了MAE代码 完全看不懂 我要⼀步⼀步来 把这篇代码给全部理解了 。我⾃⼰觉得看⼤神代码很有⽤。 这篇⽂章当笔记⽤。⼀,跑⽰例:
怎么说 ⼀上来肯定是把demo⾥的代码拿出来跑⼀跑。但是会遇到问题。 下⾯时demo的代码。 第⼀个问题是
说函数没这个参数 那很简单 到位置 删掉就⾏ 为啥我敢删 就是因为他的值是 None ,直接删就⾏
第⼆个问题是 我⼀开始把
这三个模型当成了预训练模型 , 下⾯左就是得到的结果 这啥啊 还原了个寂寞 。 想了半天kaiming是不是错了 ,再想了半天kaiming怎么
会错 ,才发现预训练模型藏在链接⾥。下⾯这三个只是他开始训练时使⽤的预训练模型。
链接在demo⾥到 两个large的 模型参数如下 跑的结果如上右 对嘛
复现结束了 (bushi)
终于把演⽰跑通了。
2 画图
调试这个⽅法可太神了,我们上⾯跑通了demo 就让我们跟着demo⼀览模型全貌吧!
这段 获取图像并且归⼀化 然后⽤plt画出来 这⾥是先归⼀化 画图时再返回回来。
(吐槽 : 我不理解 为什么要先归⼀化 再回来 再画图 多此⼀举? 我直接show img 不⾹吗)
TypeError: __init__() got an unexpected keyword argument 'qk_scale'
1
dl.fbaipublicfiles/mae/visualize/mae_visualize_vit_large.pth 2 3dl.fbaipublicfiles/mae/visualize/mae_visualize_vit_large_ganloss.pth
3 载⼊模型
3.1准备模型
会进⼊准备模型的函数⾥
对于第⼀局 getattr(models_mae,arch): 是取models_mae模块⾥的arch ⽽这个arch是什么 下图可以看到是⼀个函数 ⽽且是⼀个没带括号的函数 (我不理解 ) 所以get后要补⼀个括号
然后我们进⼊这个函数, 可以看到这个函数了 哦~ 是⼀个获取模型的函数 ⼤ 中⼩模型有三个不同的函数 不同函数的参数不⼀样罢了。
1
# load an image 2
img_url = 'user-images.githubusercontent/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_0003
# img_url = 'user-images.githubusercontent/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012
4
img = Image.(img_url, stream=True).raw)5
#raw 是⼀种格式 stream 是确定能下再下。(⽐如会事先确定内存)6
img = size((224, 224))7
img = np.array(img) / 255.8
9
assert img.shape == (224, 224, 3)10
11
# normalize by ImageNet mean and std 12
img = img - imagenet_mean 13
img = img / imagenet_std 14
15
show_sor(img))17
18
def show_image(image, title=''):19
# image is [H, W, 3]20
assert image.shape[2] == 321
plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())22
#刚才归⼀化了 现在返回 记得clip 防⽌越界 int 防⽌⼩数 因为像素都是整数 imshow 竟然可以读张量23
plt.title(title, fontsize=16)24
plt.show()25 plt.axis('off')26 return 1
chkpt_dir = 'model_save/mae_visualize_vit_large.pth'2model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')3print('Model loaded.')
1
def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):2
# build model 3
model = getattr(models_mae, arch)()4
# load model 5
checkpoint = torch.load(chkpt_dir, map_location='cpu')6
msg = model.load_state_dict(checkpoint['model'], strict=False)7
print(msg)8 9 return model
然后就是⼀个⼤⼯程了 我们进这个模型内部看⼀看。
3.2.1_模型内部
模型代码太⼤了 我就不贴整个的了 我⼀部分⼀部分的贴。
3.2.1.1 编码器模块
这个编码 来⾃于VIT的编码, 然⽽我并没有看过VIT的代码是什么样⼦的 。这篇⾥先不写 ,等到下⼀篇⽂章 我就遍历进这个编码函数⾥ 看看是什么东西。 我们就记住 有⼀个编码的函数 似乎是吧图⽚ 变成⼀串特征码
cls令牌 加⼊ 位置编码加⼊ nn.patameter这个函数 就是将⼀个不可训练的张量或者矩阵 转换为模型内可以训练的参数。 (想写⼀个要训练的参数 ⼜不是官⽅的那些层 ,终于知道⽅法啦)。cls_token⼤⼩是 (1,1,1024) 位置编码是 (1,197,1024) 为啥是197呢应该是为了跟嵌⼊cls后的编码⼤⼩保持⼀致 然后可以cat 我猜。
这⾥的 block 就是VIT⾥的那个block 这个block也等到VIT代码时再讲
这⾥有⼏个他们⽤的⼩trick
nn.ModuleList 其实就是⼀个列表 把⼀些块放在这个列表⾥ 与普通列表不同的是 普通的列表不会得到训练 。 这⾥就是放了24个⾃注意⼒块 每个块有12个头 。以上就是编码器⽤到的模块。
1
dels.vision_transformer import PatchEmbed, Block 2
3
4
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)5
#patch_size 应该是⼀个图⽚分出来的 ⼀张有多⼤ inchans ⼀般都是3 图⽚层数嘛6
# embed——dim 这个是编出来的特征维度 10247
8
9
num_patches = self.patch_embed.num_patches 10
##num_pathches ⼤⼩是x*y 就是图⽚分成x*y 份num_patches = (224/patch_size)**2 = 14 **2 = 19611
1
self.cls_token = nn.s(1, 1, embed_dim))2self.pos_embed = nn.s(1, num_patches + 1, embed_dim), 3 requires_grad=False) # fixed sin-cos embedding
1
self.blocks = nn.ModuleList([2
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)3 for i in range(depth)])4 = norm_layer(embed_dim)
1
nn.LayerNorm #这个表⽰在channel 上做归⼀化 2
nn.batchNorm #这个是在batch 上归⼀化3
DropPath # 这个也是⼀种与dropout 不同的 drop ⽅法4nn.GELU #⼀种激活函数
3.2.1.2 解码模块
下⾯是解码器。
1 self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
2 # ⼀个fc层 1024到512
3
4 self.mask_token = nn.s(1, 1, decoder_embed_dim))
5 #⼀个mask编码(1,1,512)
6 self.decoder_pos_embed = nn.s(1, num_patches + 1,
7 decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
8 #⼀个位置编码⽽且不训练(1,197,512)为什么不训练啊?
1 self.decoder_blocks = nn.ModuleList([
2 Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
3 for i in range(decoder_depth)])
4
5 self.decoder_norm = norm_layer(decoder_embed_dim)
6
7 self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
8 #预测层 512 到 256*3 (这个也不到224*224*3啊)
解码器的注意⼒层只有8层 但也是12头的 输⼊是512维
3.2.1.3 初始化模块
3.2.1.3.1 位置编码
1 _pix_loss = norm_pix_loss
2
3 self.initialize_weights()
第⼀个的值是false 等会看看有啥⽤ 第⼆个是⼀个函数 我们进去看看 。
1 pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
2 self.pos_py_(torch.from_numpy(pos_embed).float().unsqueeze(0))
初始化 第⼀步 是⼀个位置编码函数 ,我们进⼊这个编码函数去看
1def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
2 #embed_dim = 1024 是位置的最后⼀维 gridSize是每个⼩patch的长宽也就是14
3
4decoder
5 grid_h = np.arange(grid_size, dtype=np.float32)
6 grid_w = np.arange(grid_size, dtype=np.float32)
7 #⽣成两个坐标系 14*14的
8
9
10 grid = np.meshgrid(grid_w, grid_h) # here w goes first
11 #这就是⼀个坐标系了不过谁是x 谁是y还要看看
12 grid = np.stack(grid, axis=0)
13 # ⽣成了两个⽹格。每个都是14*14 grid现在是(2,14,14)
14
15
16 grid = shape([2, 1, grid_size, grid_size])
17 #(2,1,14,14)
18
19
然后继续进⼊下层函数 我们继续看 。
1 pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
2
3
4def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
5 assert embed_dim % 2 == 0
6
7 # use half of dimensions to encode grid_h
8 emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
9 emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 10
11 emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
12 return emb
再进⼊下层函数 。
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论