PyTorch之ResNet图像分类（四）


然后就可以根据论文中给出的网络结构表格搭建网络：
class ResNet(nn.Module):
    """ResNet / ResNeXt network assembled from a pluggable residual block.

    Layout: a 7x7 stride-2 convolution with batch norm and ReLU, a 3x3
    stride-2 max-pool, then four stages of residual blocks (conv2_x to
    conv5_x in the paper's configuration table) with base widths 64, 128,
    256 and 512. When ``include_top`` is true, global average pooling and a
    linear classifier are appended.

    Args:
        block: Residual block class, e.g. ``BasicBlock`` or ``Bottleneck``.
            It must expose an ``expansion`` class attribute and accept
            ``(in_channel, channel)`` positionally plus the keywords
            ``stride``, ``downsample``, ``groups`` and ``width_per_group``.
        blocks_num: Number of residual blocks in each of the four stages;
            indices 0-3 are read.
        num_classes: Output size of the final linear layer.
        include_top: Whether to build and apply the pooling + classifier
            head. Pass ``False`` to use the network as a feature extractor
            inside another module.
        groups: Forwarded unchanged to every block.
        width_per_group: Forwarded unchanged to every block.
    """

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        super().__init__()
        self.include_top = include_top
        self.groups = groups
        self.width_per_group = width_per_group

        # Running channel count fed to the next residual block. The stem
        # (7x7 conv followed by the 3x3 max-pool) produces 64 channels, so
        # conv2_x starts from 64.
        stem_channels = 64
        self.in_channel = stem_channels

        self.conv1 = nn.Conv2d(3, stem_channels, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stacked residual stages. Only conv2_x keeps the default stride of
        # 1; every later stage halves the spatial resolution.
        self.layer1 = self._make_layer(block, 64, blocks_num[0])             # conv2_x
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)  # conv3_x
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)  # conv4_x
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)  # conv5_x

        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming-normal initialisation for every convolution in the model,
        # including those created inside the residual blocks.
        conv_layers = (m for m in self.modules() if isinstance(m, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage (conv2_x .. conv5_x) as an ``nn.Sequential``.

        Args:
            block: Residual block class used for every block of the stage
                (ResNet-18/34 use ``BasicBlock``; ResNet-50/101/152 use
                ``Bottleneck``).
            channel: Base width of the stage. The stage outputs
                ``channel * block.expansion`` channels.
            block_num: How many blocks the stage stacks. The first block is
                always created; ``block_num - 1`` more follow.
            stride: Stride of the first block. Defaults to 1; stages from
                conv3_x onward are built with 2.

        Returns:
            ``nn.Sequential`` holding the blocks of the stage.

        Side effects:
            Sets ``self.in_channel`` to the stage's output channel count so
            that the next stage is wired correctly.
        """
        out_channel = channel * block.expansion

        # A projection shortcut (1x1 conv + BN, drawn as a dashed shortcut in
        # the paper) is needed whenever the first block changes the spatial
        # size or the channel count. Per the original author's note,
        # ResNet-18/34 skip it for conv2_x, while ResNet-50/101/152 need it
        # there too because the channel count changes.
        shortcut = None
        if not (stride == 1 and self.in_channel == out_channel):
            shortcut = nn.Sequential(
                nn.Conv2d(self.in_channel, out_channel, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )

        # First block of the stage: the only one that may downsample.
        stage = [
            block(self.in_channel,
                  channel,
                  downsample=shortcut,
                  stride=stride,
                  groups=self.groups,
                  width_per_group=self.width_per_group)
        ]

        # Every following block consumes the stage's output width.
        self.in_channel = out_channel

        # Remaining blocks use the block's default stride and an identity
        # (solid-line) shortcut.
        stage.extend(
            block(self.in_channel,
                  channel,
                  groups=self.groups,
                  width_per_group=self.width_per_group)
            for _ in range(1, block_num)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem and the four stages, then the head if it was built.

        Returns the output of ``self.fc`` when ``include_top`` is true,
        otherwise the output of the last residual stage.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        if not self.include_top:
            return x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
通过该类生成各个具体的网络：
def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet-34: ``BasicBlock`` stages of depth 3, 4, 6 and 3.

    Weights URL noted in the original source:
    https://download.pytorch.org/models/resnet34-333f7ec4.pth
    """
    return ResNet(
        block=BasicBlock,
        blocks_num=[3, 4, 6, 3],
        num_classes=num_classes,
        include_top=include_top,
    )


def resnet50(num_classes=1000, include_top=True):
    """Build a ResNet-50: ``Bottleneck`` stages of depth 3, 4, 6 and 3.

    Weights URL noted in the original source:
    https://download.pytorch.org/models/resnet50-19c8e357.pth
    """
    return ResNet(
        block=Bottleneck,
        blocks_num=[3, 4, 6, 3],
        num_classes=num_classes,
        include_top=include_top,
    )


def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101: ``Bottleneck`` stages of depth 3, 4, 23 and 3.

    Weights URL noted in the original source:
    https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    """
    return ResNet(
        block=Bottleneck,
        blocks_num=[3, 4, 23, 3],
        num_classes=num_classes,
        include_top=include_top,
    )


def resnext50_32x4d(num_classes=1000, include_top=True):
    """Build a ResNeXt-50 (32x4d): ResNet-50 depths, 32 groups of width 4.

    Weights URL noted in the original source:
    https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    """
    return ResNet(
        block=Bottleneck,
        blocks_num=[3, 4, 6, 3],
        num_classes=num_classes,
        include_top=include_top,
        groups=32,
        width_per_group=4,
    )


def resnext101_32x8d(num_classes=1000, include_top=True):
    """Build a ResNeXt-101 (32x8d): ResNet-101 depths, 32 groups of width 8.

    Weights URL noted in the original source:
    https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    """
    return ResNet(
        block=Bottleneck,
        blocks_num=[3, 4, 23, 3],
        num_classes=num_classes,
        include_top=include_top,
        groups=32,
        width_per_group=8,
    )