HRNet网络代码解读：Deep High-Resolution Representation Learning for Human Pose Estimation（三）


__init__ 函数
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
             num_channels, fuse_method, multi_scale_output=True):
    """Build one multi-resolution HRNet module: parallel branches + fusion.

    Args:
        num_branches: number of parallel resolution branches.
        blocks: residual block class stacked inside every branch.
        num_blocks: per-branch number of stacked blocks.
        num_inchannels: per-branch input channel counts. The list is stored
            as-is (no copy); building the branches overwrites its entries in
            place, so the caller's list is updated too.
        num_channels: per-branch base channel counts of the blocks.
        fuse_method: fusion mode; only stored here.
        multi_scale_output: when False, fusion layers are built only for
            output branch 0.
    """
    super(HighResolutionModule, self).__init__()
    # Presumably validates that the per-branch lists agree with
    # num_branches (_check_branches is defined elsewhere).
    self._check_branches(num_branches, blocks, num_blocks,
                         num_inchannels, num_channels)

    # Store the configuration first: the builders below read
    # self.num_inchannels, self.num_branches and self.multi_scale_output.
    self.num_inchannels = num_inchannels
    self.fuse_method = fuse_method
    self.num_branches = num_branches
    self.multi_scale_output = multi_scale_output

    # Parallel per-resolution branches, then the cross-branch fusion layers.
    self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                        num_channels)
    self.fuse_layers = self._make_fuse_layers()
    self.relu = nn.ReLU(inplace=True)
_make_branches 函数
def _make_branches(self, num_branches, block, num_blocks, num_channels):
    """Build all parallel branches and wrap them in an ``nn.ModuleList``.

    Calls ``self._make_one_branch`` once per branch index, in order
    0, 1, ..., ``num_branches - 1``. Order matters because each call
    updates ``self.num_inchannels`` for its own branch.

    Returns:
        ``nn.ModuleList`` with ``num_branches`` entries.
    """
    return nn.ModuleList([
        self._make_one_branch(idx, block, num_blocks, num_channels)
        for idx in range(num_branches)
    ])
_make_one_branch 函数
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                     stride=1):
    """Build branch ``branch_index`` as a stack of residual ``block``s.

    Args:
        branch_index: branch to build; indexes ``num_blocks``,
            ``num_channels`` and ``self.num_inchannels``.
        block: block class exposing a class attribute ``expansion`` and
            accepting ``(inplanes, planes[, stride, downsample])``.
        num_blocks: per-branch number of blocks to stack.
        num_channels: per-branch base width, passed to each block as
            ``planes``.
        stride: stride of the first block and of its projection shortcut.

    Returns:
        ``nn.Sequential`` of ``max(1, num_blocks[branch_index])`` blocks.
        The first block is always created.

    Side effect:
        ``self.num_inchannels[branch_index]`` is overwritten in place with
        the branch output width ``num_channels[branch_index] * block.expansion``.
    """
    in_ch = self.num_inchannels[branch_index]
    width = num_channels[branch_index]
    out_ch = width * block.expansion

    # Only the first block can change the tensor shape (stride or channel
    # count); in that case give it a 1x1 conv + BN projection so the
    # residual path matches. With the default stride=1 there is no spatial
    # downsampling.
    projection = None
    if stride != 1 or in_ch != out_ch:
        projection = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch, momentum=BN_MOMENTUM),
        )

    stacked = [block(in_ch, width, stride, projection)]
    # Every later block consumes the first block's output width. In the
    # configuration this walkthrough uses, channels are [32, 64] and
    # num_blocks is [4, 4], i.e. three more blocks follow the first one.
    self.num_inchannels[branch_index] = out_ch
    stacked.extend(
        block(out_ch, width) for _ in range(1, num_blocks[branch_index])
    )
    return nn.Sequential(*stacked)
以上这部分在每个分支上将 block 堆叠四次（num_blocks 为 [4, 4]），对应图中左侧的部分。
_make_fuse_layers 函数
def _make_fuse_layers(self):
    """Build the layers that bring every input branch j to output branch i.

    Returns ``None`` for a single-branch module. Otherwise returns an
    ``nn.ModuleList`` of rows. There is one row per output branch: all
    branches when ``self.multi_scale_output`` is true, only branch 0
    otherwise. Row ``i`` is an ``nn.ModuleList`` with one entry per input
    branch ``j``:

    * ``j > i`` (branch j has the lower resolution): a 1x1 conv + BN to
      branch i's width, then nearest-neighbour upsampling by ``2 ** (j - i)``.
    * ``j == i``: ``None``; no transform is built for a branch onto itself.
    * ``j < i`` (branch j has the higher resolution): ``i - j`` stride-2 3x3
      conv + BN steps. Every step except the last keeps branch j's width and
      ends in a ReLU. The last step maps to branch i's width with no ReLU.
    """
    if self.num_branches == 1:
        return None

    widths = self.num_inchannels
    branch_count = self.num_branches

    def upsample_path(src, dst):
        # Project channels first, then upsample to the target resolution.
        return nn.Sequential(
            nn.Conv2d(widths[src], widths[dst], 1, 1, 0, bias=False),
            nn.BatchNorm2d(widths[dst]),
            nn.Upsample(scale_factor=2 ** (src - dst), mode='nearest'),
        )

    def downsample_path(src, dst):
        # Halve the resolution (dst - src) times. Only the final step changes
        # the channel count and it has no activation.
        n_steps = dst - src
        steps = []
        for step in range(n_steps):
            last = step == n_steps - 1
            out_width = widths[dst] if last else widths[src]
            ops = [
                nn.Conv2d(widths[src], out_width, 3, 2, 1, bias=False),
                nn.BatchNorm2d(out_width),
            ]
            if not last:
                ops.append(nn.ReLU(True))
            steps.append(nn.Sequential(*ops))
        return nn.Sequential(*steps)

    rows = []
    for i in range(branch_count if self.multi_scale_output else 1):
        row = []
        for j in range(branch_count):
            if j > i:
                row.append(upsample_path(j, i))
            elif j == i:
                row.append(None)
            else:
                row.append(downsample_path(j, i))
        rows.append(nn.ModuleList(row))
    return nn.ModuleList(rows)
# Final prediction head: a single Conv2d from the width of branch 0
# (pre_stage_channels[0], presumably the highest-resolution branch) to one
# output channel per joint (cfg['MODEL']['NUM_JOINTS']; 17 in the
# configuration discussed in this walkthrough).
# padding is 1 for a 3x3 kernel and 0 otherwise, so with stride 1 the
# spatial size is preserved for kernel sizes 1 and 3.
self.final_layer = nn.Conv2d(in_channels=pre_stage_channels[0],out_channels=cfg['MODEL']['NUM_JOINTS'],kernel_size=extra['FINAL_CONV_KERNEL'],stride=1,padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0)
final_layer 的参数为 k = 1, s = 1, p = 0，输出通道数 out_channels = 17，对应 17 个关键点。
后记
其中关键的部分已经在代码中以注释的形式展现，请认真阅读注释。