㊀第53卷第3期郑州大学学报(理学版)Vol.53No.3㊀2021年9月J.Zhengzhou Univ.(Nat.Sci.Ed.)Sep.2021
收稿日期:2020-10-20
基金项目:国家自然科学基金项目(61772561);湖南省重点研发计划项目(2018NK2012)㊂
作者简介:邵伟志(1996 ),男,硕士研究生,主要从事机器学习和机器视觉研究,E-mail:a825103775@163;通信作者:潘丽丽(1977 ),女,副教授,主要从事图像处理和深度学习研究,E-mail:lily_pan@163㊂
基于一致性正则化与熵最小化的半监督学习算法
邵伟志,㊀潘丽丽,㊀雷前慧,㊀黄诗祺,㊀马骏勇
(中南林业科技大学计算机与信息工程学院㊀湖南长沙410004)
摘要:在一致性正则化与熵最小化的基础上提出一种新的半监督学习算法Mean Mixup,集成数据的互补信息,然后使用熵最小化给未标记数据生成可靠的伪标签,在一致性正则化下进一步优化模型分类结果㊂在常用数据集
SVHN 和CIFAR10上对Mean Mixup 算法进行了评估,实验结果表明,所提出的方法在分类准确率上优于一些已有的半监督学习算法㊂
关键词:半监督学习;熵最小化;一致性正则化;伪标签
中图分类号:TP391㊀㊀㊀㊀㊀文献标志码:A㊀㊀㊀㊀㊀文章编号:1671-6841(2021)03-0079-06DOI :10.13705/j.issn.1671-6841.2020320
0㊀引言
在深度学习中,大量标记数据对于神经网络的训练是至关重要的㊂但是对于许多深度学习的任务而言,获取这些标记数据是较为困难的,比如医疗任务中每一个标记都需要从专家的结论中得出㊂此外,通过网络获取的信息很大一部分是较为私密的,怎样标记这些数据也是一个复杂的问题㊂半监督学习[1]通过让模型从未标记数据中获取信息来减少对于标记数据的依赖,对于图像搜索㊁文本分类㊁文档检索[2]等任务,半监督学习都能取得很好的结果㊂近年来,半监督学习方法聚焦于在损失函数中增加损失项,这些损失项一般都是通过未标记数据取得的,促使模型更好地利用未标记数据中的信息来对数据进行分类㊂半监督学习方法可以大致分为熵最小化[3]㊁一致性正则化[4]与传统正则化[5]三类,但是已有方法往往很容易忽视数据互补信息与多阶段模型共同作用的优势㊂在上述研究基础上,本文提出一种新的半监督算法Mean Mixup,重新考虑模型生成伪标签的方法,通过多种数据增强方法使得模型能够学习数据的互
补信息,并且让不同阶段的模型共同作用,最终让模型产生低熵预测来获取更准确的伪标签㊂为更好地应用一致性正则化,对标记数据和未标记数据进行混洗之后,根据数据类型传入不同的一致性损失函数,并使用比重系数调节来让模型能更好地从一致性正则化中受益㊂在常用数据集SVHN 和CIFAR10上的实验结果验证了新算法的有效性,其在分类准确率上优于Pseudo-label㊁Π-model 等半监督学习算法㊂
1㊀熵最小化与一致性正则化
1.1㊀熵最小化
在深度学习中,聚类假设指出,模型的决策边界最好不通过边缘数据分布的高密度区域,也就是使得模型输出分布的熵尽可能小,这样会使模型获得更好的泛化性[6]㊂体现在学习过程中的熵最小化是让模型对目标数据的分类结果尽量自信,使得模型的决策边界尽量远离边缘数据点,同时让模型的拟合曲线更贴合数据的边缘分布㊂图1对比了双月系统中原始决策边界与熵最小化约束下的决策边界㊂半监督学习算法中经常通过添加损失项来使得模型在未标记数据的概率分布实现熵最小化㊂Pseudo-label 算法[3]对未标记数据进行预测,利用熵最小化获得置信度高的预测分布作为伪标签,并将其用作标准交叉熵损失的训练目标[2]㊂
郑州大学学报(理学版)第53卷本文提出的Mean Mixup算法类似于Pseudo-label算法,都对未标记数据构
建伪标签,不同之处在于Mean Mixup算法对如何得到伪标签进行了新的设计㊂
图1㊀原始决策边界和熵最小化约束下的决策边界对比
Figure1㊀Comparison of original decision boundary and decision boundary under entropy minimization constraint
1.2㊀一致性正则化
一致性是指模型对受到扰动的数据点应输出相同的分布预测,半监督学习算法的很多突破性进展都是在一致性正则化基础上取得的㊂Π-model算法[4]通过在随机模型fθ(x)对同一样本的预测之间施加约束来实现一致性正则化㊂VAT算法[7]直接对输入x增加扰动,并且这种扰动能使得预测产生最大偏移,在受扰动样本与未受扰动样本产生的输出分布之间施加一致性约束㊂Mean teacher算法[8]通过构建教师-学生框架来实现一致性约束,在这个框架中使用了两个结构一致的网络,且教师网络的参数是学生网络参数的指数移动平均值,为更加直观,在本文中分别称为原型网络与指数网络㊂指数网络输入样本是对原型网络输入样本的加噪值,在两个网络的预测分布之间通过应用KL散度或者交叉熵函数来施加一致性约束㊂一致性正则化对于深度学习,尤其是半监督学习而言很有帮助,使得模型能够从标记数据的标签信息之外得到更多的高维特征信息㊂
2㊀Mean Mixup算法
2.1㊀算法概述
本文在网络结构的选择上使用原型网络与指数网络,并在构建伪标签的方法以及一致性损失函数的计算上进行了创新㊂给定一个批次数的标记样本X和一个同等批次数的未标记样本U,Mean Mixup算法对未标记数据产生一个伪标签,从而得到带伪标签的Uᶄ㊂标记数据使用数据增强得到Xᶄ,Xᶄ与Uᶄ分别与两者连接而成的数据W使用数据混洗得到X^和U^㊂主要的算法表达式有:xᶄ=augment(x),uᶄ=guesslabel(u),w= shuffle(concate(xᶄ,uᶄ)),x^,u^=mixup(xᶄ,uᶄ,w)㊂混洗得到的X^和U^传入网络后,分别计算带标记数据分类损失L x和未标记数据分类损失L u,此外还需要计算一致性损失L con,并在损失函数中使用λu和λc作为超参数来调节各损失项所占比重,本算法最终的损失函数为
L loss=L
x +λ
u
L
u
c
L con㊂(1)
2.2㊀集成互补信息的熵最小化伪标签
由于确认偏差[9]的原因,直接对未标记数据生成伪标签很容易对错误标签过度自信,从而不会继续从未标记数据中进行学习㊂Mean Mixup算法使用在同一模型的多个变种的共同作用下生成伪标签的方法,使伪标签的生成能获得多个角度的互补信息,对于未标记数据的判断更加可靠㊂为了得到未标记数据的软标签[10],使得未标记数据可以随着网络学习不断更新伪标签的生成结果,这种新的伪标签获取方法保证了网络能够从不同的角度和时间段受益,并逐渐提升伪标签的准确度㊂伪标签猜测流程如图2所示㊂
数据增强在半监督学习中能缓解标记数据的不足,而在Mean Mixup算法中也是生成伪标签的重要步骤㊂图2显示了数据增强的几种变化㊂对于标记数据,数据增强后得到Xᶄ,其中xᶄ=augment(x)㊂对于未08
㊀第3期邵伟志,等:
基于一致性正则化与熵最小化的半监督学习算法图2㊀伪标签猜测流程
Figure 2㊀Pseudo label guessing process
标记数据,应用数据增强K 次,得到未标记数据增强后的K 个实例,u ^k
=augment (u )㊂Mean Mixup 算法使用原型网络与指数网络给未标记数据生成一个预测分布,K 个增强实例传入模型的预测分布为
q ^1=12K ðK k =1(P model (y u ^k ;θ)+P ema-model (y u ^k ;θ))㊂(2)
㊀㊀为了使网络可以获得更多不同角度的信息[11],使用不同时间段的网络对未标记数据进行预测㊂为了专注于获取原样本的特征信息,仅将原未标记数据传入其他时间段的网络,并进行如下计算:
q ^=12(q ^1+P n-model (y u ^k ;θ)),(3)
其中:P n-model (y u ^k ;θ)表示前n 个训练轮时的网络,一般取n =5㊂如图2所示,为了使得预测分布的熵最小化,使用了锐化函数㊂对于预测分布q
^,应用调整分布 温度(T ) 的通用操作[12],锐化函数可以表示为q =SharPen (q ^,T )i =q ^1T i /ðL j =1q ^1T j ,(4)
其中:T 是超参数,T ң0时锐化函数的输出趋近 one-hot 编码㊂文献[12]中指出,降低 温度(T ) 有利于模型产生低熵预测㊂在运行过程中,算法对每一个批次的未标记数据都执行以上的方法计算伪标签,这种构建伪标签的方法使得伪标签的准确度随着模型学习不断提升㊂
2.3㊀混洗数据的一致性约束
一致性约束会使模型拥有更好的抗干扰能力,以往一致性约束通常是添加在网络的预测分布之间,区别
是对输入样本加噪或者是网络参数的变化㊂但在很多应用中,标记数据与未标记数据经常出现分布不匹配,甚至某一类的标记样本数极少,模型难以获取足够的信息㊂对数据使用Mixup 进行混洗来弥补两类数据之间的差异,使得模型学习的拟合曲线更符合数据分布,同时Mixup 还实现了传统正则化对于网络的调节作用[13]㊂Mixup [12]中对于两个带标签的样本(x 1,P 1)和(x 2,P 2),其混合后的目标(xᶄ,Pᶄ)为xᶄ=λᶄx 1+(1-λᶄ)x 2;Pᶄ=λᶄP 1+(1-λᶄ)P 2,(5)其中:λᶄ=max(λ,1-λ),λ在Beta (α,α)内取值,α是超参数㊂W 分别对Xᶄ和Uᶄ进行Mixup 混洗,从而得
到了新的数据X
^和U ^,此时两者都带有标签,可用于交叉熵等损失函数㊂一致性约束的实现方式如图3所示,数据分别传入指数网络和原型网络,计算预测分布之间的差异㊂
未标记数据的一致性损失L c 1和标记数据的一致性损失L c 2可以分别表示为L c 1=1u ^ðu ,q ɪu ^ P ema-model (y u +ε;θ)-P model (y u ;θ) 22,(6)
1
8
郑州大学学报(理学版)第53卷L c 2=1x ^ðx ,P ɪx ^ P ema-model (y x +ε;θ)-P model (y x ;θ) 2
2㊂
正则化 归一化(7)
图3㊀一致性约束的实现方式Figure 3㊀Implementation of consistency constraint 2.4㊀损失函数模型通过损失函数L loss 来进行梯度计算并更新参数㊂式(1)中L x 为使用交叉熵函数计算的标记数据X
^分类损失,L x =1xᶄðx ,P ɪxᶄH (P ,P model (y x ;θ))㊂(8)㊀㊀L u 为使用L2损失函数计算的未标记数据U
^分类损失,L u =1L uᶄðu ,q ɪuᶄ q -P model (y u ;θ) 22㊂(9)
㊀㊀损失函数中一致性损失项L con 为L c 1与L c 2之和,即L con =L c 1+L c 2㊂损失函数中的未标记数据分类损失和一致性损失通
过L2损失函数计算㊂L2损失函数与交叉熵不同,它是有界的,而且对完全错误的判断不太敏感,经常用作半监督学习中对未
标记数据预测的损失以及预测结果不确定性的度量[14]㊂
3㊀实验结果及分析
本文将提出的Mean Mixup 算法在TensorFlow2.0平台上实现,并与Mean teacher [8]㊁VAT [7]㊁Π-mode
l [4]㊁MixMatch [14]以及Pseudo-label [3]算法进行了比较㊂所有算法选择的网络均为 Wide ResNet-28-2 结构,但并没有使用学习率周期表而只使用了学习率衰减,选取运行100轮后得到的结果进行对比㊂Pseudo-label 与MixMatch 算法的对比结果是在TensorFlow2.0平台上进行复现得到的,其他算法的实验结果来自文献[15],选取的对比指标为错误率㊂MixMatch 算法根据文献[14]选择超参数与学习率,并选取运行100轮后的结果作对比㊂
3.1㊀CIFAR10数据集CIFAR10是一个深度学习常用数据集,包含50000张训练样本以及10000张测试样本,每个样本都是32∗32的RGB 图片,并且分属于10个类别,类别各自独立,不会产生重叠㊂遵循常规半监督学习的设置,实验中使用了4000个带标记的样本㊂设定Mean Mixup 算法学习率为0.002,对输入图片只进行了归一化处理㊂结果表明,Pseudo-label㊁Mean teacher㊁VAT㊁Π-model㊁MixMatch 算法的错误率分别为15.54%㊁15.87%㊁13.86%㊁16.37%㊁7.24%,而Mean Mixup 算法的错误率仅为6.37%㊂从实验结果可知,在CIFAR10数据集中VAT 算法比同样使用一致性正则化作为主要指导思想的Mean teacher 算法表现要好,这可能是由于噪声的方向选择能够使得模型更好地学习㊂MixMatch 和Mean Mixup 算法的错误率比单纯一致性正则化的Mean teacher㊁VAT 以及Π-model 算法低,这证明了在半监督学习中使用熵最小化构建伪标签是有效的㊂为了对Mean Mixup 算法进行更详细的实验论证,分别在CIFAR10数据集中选择了250㊁500㊁1000㊁2000个标签进行100轮的实验,算法的错误率结果分别为18.70%㊁1
4.86%㊁11.42%㊁7.64%㊂更少标签数据下的实验结果表明,在相同网络架构中,Mean Mixup 算法在仅使用2000个标记样本的情况下接近甚至超过经典半
监督算法的表现,证明了Mean Mixup 算法对于标签样本的利用率更高㊂
3.2㊀SVHN 数据集SVHN 数据集来源于谷歌街景门牌号码,经过裁剪成为32∗32的RGB 图片,包含73257个训练样本和26032个测试样本,被划分为10个类别,设定学习率为0.002㊂将Mean Mixup 算法与经典半监督算法在使用4000个标签样本运行100轮的实验结果进行了对比㊂结果表明,Pseudo-label㊁Mean teacher㊁VAT㊁Π-model㊁MixMatch 算法的错误率分别为5.37%㊁5.65%㊁6.31%㊁7.19%㊁3.89%,而Mean Mixup 算法的错误率仅为2.87%㊂在相同的标签数据下,Mean Mixup 算法的分类错误率较其他半监督算法更低,并且相较于使用单一正则化的Pseudo-label 等方法优势较为明显㊂同时,在CIFAR10数据集中比Mean teacher 算法表现
2
8
㊀第3期邵伟志,等:基于一致性正则化与熵最小化的半监督学习算法更好的VAT 算法在SVHN 数据集中并没有体现出优势,表明了在难以获得足够多的标签信息的半监督学习中,只使用一致性正则化或熵最小化很难获得出众的结果㊂为验证Mean Mixup 算法在更少标签数据下的表
现,进行了四组少标签(250㊁500㊁1000㊁2000个标签)数据实验,算法的错误率结果分别为9.13%㊁8.07%㊁6.58%㊁5.30%㊂Mean Mixup 算法在只有2000个标签数据的情况下依然取得了错误率为5.30%的成绩,这与4000个标签数据下Pseudo-label 算法的结果相近,且高于VAT 和Π-model 算法,再一次验证了Mean Mixup 算法对于标签样本的利用率更高
㊂图4㊀伪标签准确率结果Figure 4㊀Pseudo label accuracy results
3.3㊀标签猜测准确度
为了验证所得伪标签的准确度,将使用验证集猜测
得到的标签与其自带的标签进行对比,每隔20轮记录下
准确率,伪标签准确率结果如图4所示㊂可以看出,通过
集成数据互补信息进而获得低熵伪标签的方法是有效
的,且在250个标签数据的情况下所得的伪标签准确率
也达到了85.78%,与4000个标签数据下的准确率差距不大,这表明Mean Mixup 算法在标签数据稀少的情况下生成伪标签的准确度依然较高㊂
3.4㊀超参数选择
在Mean Mixup 算法中有四个较为重要的超参数,分别为未标记数据增强次数K ㊁Mixup 中取样区间λ以及未
标记数据分类损失与一致性损失各自的比重系数λu 和λc ㊂为了更直观地展示超参数的选择,同时避免超参数细微变化所带来的不公平的性能比较,仅选择了四
组超参数在数据集中进行实验,其中依照MixMatch 算法中的设置使得α=0.75,遵循Mean teacher 算法使得λc =1㊂对于每组超参数,均使用4000个标签,所应用的数据预处理方式与优化器都是一致的,最终选取其运行100轮后的错误率来进行对比㊂结果表明:K =1,λu =75时错误率为7.37%;K =1,λu =150时错误率为7.86%;K =2,λu =75时错误率为6.37%;K =2,λu =150时错误率为6.65%㊂从四组不同超参数对比实验结果中可以看出,未标记数据增强次数K 对于结果的影响较大,这是由于在生成伪标签的过程中,多个不同增强实例的反馈能增强伪标签的准确度㊂而一致性损失值在实验过程中一直较小,需要使用较大的比重系数才能使网络从一致性损失中进行学习,所以直接从75增大到150对于实验结果的影响也不明显㊂因此,对比实验中超参数的选择为K =2,α=0.75,λu =75,λc =1㊂
4㊀结论
本文针对以往半监督算法往往忽略数据互补信息的不足,提出了一种新的半监督算法Mean Mixup㊂该方法能够有效利用少量标签带来的信息,并推广到未标记数据上㊂Mean Mixup 算法基于熵最小化与一致性正则化的思想,设计了通过多阶段模型共同作用,集成多角度信息从而生成低熵伪标签的方法,并利用一致性正则化优化了模型的分类性能㊂在经典数据集CIFAR10和SVHN 上与现有的半监督算法进行了
比较,实验结果表明,在相同标签数的情况下,Mean Mixup 算法的分类准确度较之前的半监督方法表现更好㊂即使在更少标签数据的情况下,Mean Mixup 算法获得的准确度也超过了之前使用单一正则化的半监督方法㊂本文还验证了生成伪标签的准确度,发现即使在标签数据稀少的情况下,生成伪标签的准确度依然较高,表明Mean Mixup 在解决半监督学习问题上是有效的,且集成数据信息生成伪标签的方法是正确的㊂
参考文献:
[1]㊀ZHU X J,GOLDBERG A B.Introduction to semi-supervised learning [M ].San Rafael:Morgan and Claypool Publishers,
2009.
[2]㊀刘欢,徐健,李寿山.基于变分自编码器的情感回归半监督领域适应方法[J].郑州大学学报(理学版),2019,51(2):47-51.
38

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。