Constructor: __init__

def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
  • The initializer defines the network's main structure and parameters (an instantiation sketch follows this list).
  • channel: number of channels of the input feature map.
  • dim: feature dimension of the Transformer part.
  • depth: number of Transformer layers.
  • kernel_size: kernel size of the convolution layers.
  • patch_size: size of the patches the feature map is split into.
  • mlp_dim: dimension of the feed-forward network inside the Transformer.
  • dropout: dropout rate, used for regularization.
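As a minimal sketch of how these parameters come together (all values here are illustrative, not from the source; IRBlock and UserDefined are assumed to be defined as in the rest of the article):

# Hypothetical instantiation; the concrete values are illustrative only.
block = MobileViTBv3(
    channel=64,         # input/output channels of the block
    dim=64,             # token dimension inside the Transformer
    depth=2,            # number of Transformer layers
    kernel_size=3,      # n x n convolution kernel
    patch_size=(2, 2),  # each token group covers a 2x2 patch grid
    mlp_dim=128,        # hidden width of the Transformer feed-forward network
    dropout=0.1,
)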

Defining the network layers

self.mv01 = IRBlock(channel, channel)
self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
self.conv3 = conv_1x1_bn(dim, channel)
self.conv2 = conv_1x1_bn(channel, dim)
self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
  • IRBlock, conv_nxn_bn, and conv_1x1_bn are used for feature extraction and dimension transformation.
  • UserDefined is the Transformer-based structure mentioned earlier, used to process sequence data.
  • Combining these layers exploits the CNN's spatial feature extraction and the Transformer's sequence modeling (the shape trace below makes this concrete).
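To make the division of labor explicit, here is an illustrative trace of tensor shapes through the block, written as comments, for an input of shape (B, C, H, W) with C = channel and D = dim (it follows from the forward pass shown later):

# Illustrative shape trace (B = batch, C = channel, D = dim):
#   input x:     (B, C, H, W)
#   conv1:       (B, C, H, W)    n x n conv, local spatial features
#   conv2:       (B, D, H, W)    1x1 conv, project C -> D for the Transformer
#   transformer: shape-preserving on the token sequence
#   conv3:       (B, C, H, W)    1x1 conv, project D -> C
#   cat(x, z):   (B, C + D, H, W)  concatenate with the pre-Transformer copy z;
#                                  conv4 expects 2C inputs, which implies D == C
#   conv4:       (B, C, H, W)    n x n conv, fuse local and global features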

# 1x1 convolution + BatchNorm + SiLU: remaps channels, keeps spatial size.
def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
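A quick, hypothetical shape check (the tensor sizes are arbitrary): a 1x1 convolution only remaps channels, so the spatial dimensions pass through unchanged.

import torch

layer = conv_1x1_bn(32, 64)
out = layer(torch.randn(1, 32, 16, 16))
print(out.shape)  # torch.Size([1, 64, 16, 16]): channels 32 -> 64, 16x16 kept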

# n x n convolution + BatchNorm + SiLU. The padding is hard-coded to 1,
# which preserves spatial size only for kernel_size=3 with stride=1.
def conv_nxn_bn(inp, oup, kernel_size=3, stride=1):
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernel_size, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU()
    )
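Another hypothetical check: with the default kernel_size=3 and stride=1, the fixed padding of 1 keeps the spatial size, which is what the block later relies on for concatenation and the residual addition.

import torch

layer = conv_nxn_bn(32, 32)  # kernel 3, stride 1, padding 1
out = layer(torch.randn(1, 32, 16, 16))
print(out.shape)  # torch.Size([1, 32, 16, 16]): spatial size preserved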

Forward pass: forward

def forward(self, x):
    y = x.clone()              # keep the input for the residual connection
    x = self.conv1(x)          # local feature extraction
    x = self.conv2(x)          # project channel -> dim
    z = x.clone()              # keep the projected features for later fusion
    _, _, h, w = x.shape
    x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
    x = self.transformer(x)    # global processing on the token sequence
    x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
    x = self.conv3(x)          # project dim -> channel
    x = torch.cat((x, z), 1)   # fuse global (x) and pre-Transformer (z) features
    x = self.conv4(x)
    x = x + y                  # residual connection with the original input
    x = self.mv01(x)
    return x
  • The forward method defines how data flows through the network.
  • The input x first passes through several convolution layers for feature extraction and dimension transformation.
  • The result is then reorganized (rearrange) into patch tokens and fed to the Transformer structure.
  • After the Transformer, the data is rearranged back into its original spatial layout (a shape trace follows this list).
  • After further convolution, a residual connection adds the original input, and the result passes through another IRBlock.
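To make the two rearrange calls concrete, here is a hypothetical trace with ph = pw = 2 and illustrative sizes (b=1, d=64, an 8x8 feature map):

import torch
from einops import rearrange

x = torch.randn(1, 64, 8, 8)   # (b, d, H, W)
t = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=2, pw=2)
print(t.shape)                 # torch.Size([1, 4, 16, 64]): 4 pixels per patch, 16 patches
back = rearrange(t, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=4, w=4, ph=2, pw=2)
print(torch.equal(x, back))    # True: the mapping is fully invertible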

Full code:

import torch
import torch.nn as nn
from einops import rearrange

# IRBlock, UserDefined, conv_1x1_bn, and conv_nxn_bn are defined earlier in the article.

class MobileViTBv3(nn.Module):
    def __init__(self, channel, dim, depth=2, kernel_size=3, patch_size=(2, 2), mlp_dim=int(64*2), dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size
        self.mv01 = IRBlock(channel, channel)
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = UserDefined(dim, depth, 4, 8, mlp_dim, dropout)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        y = x.clone()
        x = self.conv1(x)
        x = self.conv2(x)
        z = x.clone()
        _, _, h, w = x.shape
        x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        x = self.transformer(x)
        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h//self.ph, w=w//self.pw, ph=self.ph, pw=self.pw)
        x = self.conv3(x)
        x = torch.cat((x, z), 1)
        x = self.conv4(x)
        x = x + y
        x = self.mv01(x)
        return x
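Finally, a hypothetical smoke test, assuming IRBlock and UserDefined from the earlier sections are in scope and preserve tensor shapes. Note that torch.cat((x, z), 1) produces channel + dim channels while conv4 expects 2 * channel inputs, so the block implicitly assumes dim == channel; the height and width must also be divisible by the patch size.

# Hypothetical smoke test; values are illustrative.
block = MobileViTBv3(channel=64, dim=64)  # dim == channel, see the note above
block.eval()                              # use running BatchNorm statistics
with torch.no_grad():
    out = block(torch.randn(1, 64, 32, 32))
print(out.shape)  # torch.Size([1, 64, 32, 32]): same shape in and out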