论文解读《Semi( 二 )


2.2交叉解码器知识蒸馏(CDKD)
引入了CDKD来增强MTNet利用未标记图像的能力,并消除带有噪声的伪标签的负面影响 。它迫使每个解码器都受到其他两个解码器的软预测的监督 。遵循KD[5]的做法,使用温度校准的 (T-)来软化概率图:
式中,zc表示像素c类的logit预测值,pc表示c类的软概率值 。温度T是控制输出概率软度的参数 。注意,T = 1对应的是一个标准的函数,T值越大,概率分布越软,熵越高 。当T式3 为锐化函数 。
令PcA、PsA和PcsA分别表示对三个分支的软概率图 。
另外两个分支为该分支的老师指导学习,CSA分支的KD损失为:
式中KL()为-散度函数 。请注意,
的梯度只反向传播到CSA分支,因此知识是从教师提炼到学生的 。同样,CA和SA分支的KD损失分别记为
。则总蒸馏损失定义为:
class KDLoss(nn.Module):"""Distilling the Knowledge in a Neural Networkhttps://arxiv.org/pdf/1503.02531.pdf"""def __init__(self, T):super(KDLoss, self).__init__()self.T = Tdef forward(self, out_s, out_t):loss = (F.kl_div(F.log_softmax(out_s / self.T, dim=1),F.softmax(out_t / self.T, dim=1), reduction="batchmean") # , reduction="batchmean"* self.T* self.T)return lossoutputs1, outputs2, outputs3 = model(inputs)kd_loss = KDLoss(T=10)cross_loss1 = kd_loss(outputs1.permute(0, 2, 3, 1).reshape(-1, 2),outputs2.detach().permute(0, 2, 3, 1).reshape(-1, 2)) + \kd_loss(outputs1.permute(0, 2, 3, 1).reshape(-1, 2),outputs3.detach().permute(0, 2, 3, 1).reshape(-1, 2))cross_loss2 = kd_loss(outputs2.permute(0, 2, 3, 1).reshape(-1, 2),outputs1.detach().permute(0, 2, 3, 1).reshape(-1, 2)) + \kd_loss(outputs2.permute(0, 2, 3, 1).reshape(-1, 2),outputs3.detach().permute(0, 2, 3, 1).reshape(-1, 2))cross_loss3 = kd_loss(outputs3.permute(0, 2, 3, 1).reshape(-1, 2),outputs1.detach().permute(0, 2, 3, 1).reshape(-1, 2)) + \kd_loss(outputs3.permute(0, 2, 3, 1).reshape(-1, 2),outputs2.detach().permute(0, 2, 3, 1).reshape(-1, 2))cross_consist = (cross_loss1 + cross_loss2 + cross_loss3)/3
KL散度(- ,简称KL散度)是一种度量两个概率分布之间差异的指标,也被称为相对熵( ) 。
2.3 基于平均预测的不确定性最小化
例如,两个分支分别预测像素的一种类别概率为0.0和1.0 。为了避免这个问题,并进一步鼓励解码间的一致性,我们提出了一种基于平均预测的不确定性最小化方法:
其中
为平均概率图 。C和N分别为类号和像素数量 。P是像素i处c类的平均概率 。
outputs1, outputs2, outputs3 = model(inputs)outputs1_soft = torch.softmax(outputs1, dim=1)outputs2_soft = torch.softmax(outputs2, dim=1)outputs3_soft = torch.softmax(outputs3, dim=1)outputs_avg_soft = (outputs1_soft+outputs2_soft+outputs3_soft)/3en_loss = entropy_loss(outputs_avg_soft, C=2)
最后,我们的 CDMA的整体损失函数 为:
其中
为标记训练图像上三个分支的平均监督学习损失,每个分支的监督学习损失计算概率预测(PcsA, PcA和PsA)与标签之间的Dice损失和交叉熵损失 。入1和入2分别是Lcdkd和Lum的权值 。
都应用于标记和未标记的训练图像 。
loss_sup = 0.5*dice_loss(outputs1_soft[:args.labeled_bs], labels[:args.labeled_bs])+0.5*F.cross_entropy(outputs1[:args.labeled_bs], labels[:args.labeled_bs,0,:,:].long()) + \0.5*dice_loss(outputs2_soft[:args.labeled_bs], labels[:args.labeled_bs])+0.5*F.cross_entropy(outputs2[:args.labeled_bs], labels[:args.labeled_bs,0,:,:].long()) + \0.5*dice_loss(outputs3_soft[:args.labeled_bs], labels[:args.labeled_bs])+0.5*F.cross_entropy(outputs3[:args.labeled_bs], labels[:args.labeled_bs,0,:,:].long())loss_sup = loss_sup/3
三、和其他方法对比