admin 管理员组

文章数量: 1184232

背景

伪标签半监督学习方法中,伪标签的选择不容易,在模型训练初期容易选出误差较大的伪标签导致模型性能不佳;unsupervised loss中的权重系数不好确定。

提出方法:MixMatch

将当前用于半监督学习的主要方法相结合,以生成一种新算法 MixMatch,该算法猜测数据增强未标记示例的低熵标签,并使用 MixUp 混合标记和未标记数据。

只采用了250个标签,就减小了错误率。

半监督学习(SSL) 旨在通过允许模型利用未标记数据来很大程度上减轻对标记数据的需求。其中一种半监督学习方法是在损失函数中添加一个损失项,该损失项是在未标记的数据上计算的,并鼓励模型更好地泛化到看不见的数据。

其中损失项属于以下三类之一:

熵最小化——它鼓励模型对未标记的数据输出可信的预测;

一致性正则化——它鼓励模型在其输入受到扰动时产生相同的输出分布;

通用正则化——它鼓励模型很好地泛化并避免过度拟合训练数据。

MixMatch,是一种 SSL 算法。它引入了一个单一的损失,可以将以上损失项统一到一个半监督学习方法中。与以前的方法不同,MixMatch 为未标记数据引入了一个统一的损失项,可以无缝地降低熵,同时保持一致性并与传统的正则化技术保持兼容。

方法核心:MixUp

Label Guessing

标签猜测,目的是达到和伪标签相同的作用,但与伪标签不同。

对于未标记数据集中的每个未标记示例,MixMatch 使用模型的预测为示例的标签生成一个“猜测”。这个猜测后来被用在无监督损失项中

Sharpening

锐化

受半监督学习中熵最小化的启发,在生成标记猜测时,我们执行了一个额外的步骤。给定对增强数据 的平均预测,应用锐化函数来降低标签分布的熵。 在实践中,通过调节“温度”的方式来控制锐化函数的使用范畴。

MixUp

:labeled data, :unlabeled data

有标记数据和无标记数据所取batch大小相同,其中无标记数据会经过K个增广

算法步骤

 首先对有标记的每个样本做数据增广;

接着对无标记的每个样本做K次数据增广,文中K=2,即做两次数据增广;

将经过增广的无标记数据送入模型,每一个数据会预测一个结果;对结果取均值,再求Sharpen,这就是label guessing的操作。

上述操作后,得到两份数据,第一份:增广后有标记的数据;第二份:增广后有标记的无标记数据(这里的标记是经过训练猜测出来的) 

将有标记的和无标记的数据concat起来,组合成一个大的batch,然后将其随机打乱,然后跟原始数据混合起来(MinUp)。

混合之后的数据再进入模型进行前向计算,最后求loss。

其中loss function定义:

模型建立

超参数

损失函数-指数移动平均 

an exponential moving average of model parameter values, it provides a more stable target and was found empirically to significantly improve results.

 通过指数平均来调整模型,使模型更稳定

rampup

we only change α and on a per-dataset basis; we found that = 0.75 and = 100 ar

本文标签: 核心 代码 Approach Holistic MixMatch