文章目录
[2020] (ICML)SimCLR
摘要
本文介绍了一个简单的视觉表征对比学习框架SimCLR。我们简化了最近提出的对比自监督学习算法,而不需要专门的架构或存储库。为了理解是什么使得对比预测任务能够学习有用的表示,我们系统地研究了我们框架的主要组成部分。我们表明(1)数据增强的组成在定义有效的预测任务中起着至关重要的作用,(2)在表示和对比损失之间引入一个可学习的非线性转换( a learnable nonlinear transformation),大大提高了学习表示的质量,(3)对比学习比监督学习受益于更大的批处理规模和更多的训练步骤(larger batch sizes and more training steps)。通过结合这些发现,我们能够大大超过以前在ImageNet上进行自我监督和半监督学习的方法。一个由SimCLR学习的自我监督表示训练的线性分类器达到了76.5%的top-1准确率,这比以前的最先进水平相对提高了7%,与有监督的ResNet-50的性能匹配。当只对1%的标签进行微调时,我们实现了85.8%的top-5准确率,在减少100个标签的情况下优于AlexNet。
Introduction
-
在没有人类监督的情况下学习有效的视觉表示是一个长期存在的问题。大多数主流的方法可以分为两类:生成式的和判别式的。
-
生成方法学习在输入空间中生成或建模像素(Hinton等人)。然而,像素级生成的计算成本很高,可能不是表示学习所必需的。 判别方法使用与监督学习类似的目标函数来学习表示,但训练网络执行借口任务(pretext tasks),其中输入和标签都来自于一个未标记的数据集。许多此类方法依赖启发式来设计借口任务,这可能会限制学习表征的普遍性。在潜在空间中,基于对比学习的判别方法最近显示出巨大的前景,取得了最新的成果。
-
在本研究中,我们介绍了一个简单的视觉表征对比学习框架,我们称之为SimCLR。SimCLR不仅优于以前的工作(图1),而且它也更简单,不需要专门的架构,也不是一个存储库。
-
为了理解是什么促成了良好的对比表征学习,我们系统地研究了我们的框架的主要组成部分,并展示了这一点:
-
在定义产生有效表示的对比预测任务时,多个数据增强操作的组合至关重要。此外,无监督对比学习得益于比监督学习更强的数据增强。
-
在表示和对比损失之间引入一个可学习的非线性转换,可以大大提高学习表示的质量。
-
具有对比交叉熵损失的表示学习受益于标准化嵌入(normalized embeddings)和适当调整的温度参数。
-
**对比学习与监督式学习相比,更大的批处理规模和更长的培训时间有利于对比学习。**与监督学习一样,对比学习得益于更深入更广泛的网络。
方法
The Contrastive Learning Framework
培训结束后,我们丢弃投影头g(·),使用编码器f(·)和表示h来处理下游任务。
受最近的对比学习算法的启发(见第7节概述),SimCLR通过在隐空间中的对比损失最大化同一数据示例的不同增强视图之间的一致性来学习表示。如图2所示,该框架包含以下四个主要组件。
SimCLR框架,正如全文所示,非常简单。取一幅图像,对其进行随机变换,得到一对增广图像x_i和x_j。该对中的每个图像都通过编码器以获得图像的表示。然后用一个非线性全连通层来获得图像表示z,其任务是最大化相同图像的z_i和z_j两种表征之间的相似性。
- 一个随机data augmentation模块,它将任意给定的数据示例随机转换为同一个示例的两个相关视图,表示为xi和xj,我们认为它们是正对。在这项工作中,我们依次应用了三个简单的增强:随机裁剪,然后调整到原始大小,随机颜色扭曲random color distortions和随机高斯模糊。如第3节所示,随机裁剪和颜色失真的结合是获得良好性能的关键。一种神经网络基编码器f(·),从扩充数据示例中提取表示向量。
- 我们的框架允许在没有任何限制的情况下选择各种网络架构。我们选择简单和采用常用的ResNet,得到hi = f(xi) = ResNet(xi),其中hi 是平均池化层后的输出。
- 一个小的神经网络投影头g(·),它将表示映射到应用对比损失的空间。我们使用一个隐含层的MLP得到zi = g(hi) = W(2)σ(W(1)hi),其中σ是一个ReLU非线性。如第4节所示,我们发现定义zi 而不是hi 上的对比损失是有益的。
- 为对比预测任务定义的对比损失函数。给定一个集合{x_k},其中包含一个正的例子xi和xj,对比预测任务的目的是分辨给定xi在{xk}(k!=i)中的xj。
我们随机抽取了N个样本的minibatch,并定义了由这个minibatch衍生的增强样本对的对比预测任务,得到2N个数据点。We do not sample negative examples explicitly。相反,给定一个正对,类似于(Chen et al., 2017),我们将一个minibatch中的其他2(N-1)增例视为负例。设
表示标准化u和V之间的点积(即余弦相似度)。则正例对(i, j)的损失函数定义为
其中[k!=i]{0,1}是在k!=i时计算到1的指示函数,τ表示温度参数。最后的损失是在一个小批处理中所有正对(i, j)和(j, i)上计算的。这个损失在之前的工作中已经使用过(Sohn, 2016;吴等人,2018;Oord et al., 2018);为了方便起见,我们将其命名为NT-Xent(标准化的温度尺度交叉熵损失the normalized temperature-scaled cross entropy loss)。
Training with Large Batch Size
为了保持简单,我们不使用记忆库来训练模型(Wu et al., 2018;他等人,2019)。相反,我们将训练批大小N从256更改为8192。批处理大小为8192的两个增强视图中,每个正对给出16382个反例。当使用具有线性学习速率缩放的标准SGD/Momentum时,大批量的训练可能不稳定(Goyal等人,2017)。为了稳定培训,我们对所有批大小使用LARS优化器(You等人,2017年)。我们使用云tpu训练我们的模型,根据批大小使用32到128核.
除非另有规定,为数据增加我们用随机作物和调整(随机翻转),颜色失真,和高斯模糊(有关的细节,请参阅附录A)。我们使用ResNet-50作为基础网络编码器, 2-layer MLP投影表示到128 -维潜在空间。作为损失,我们使用NT-Xent,使用LARS进行优化,学习率为4.8 (= 0.3 BatchSize/256),权重衰减为10e- 6。我们以4096批次培训100个epoch此外,我们对前10个周期使用线性预热,并使用余弦衰减(cosine decay)计划来衰减学习速率而不重启(Loshchilov &Hutter, 2016)。
对比表征学习的数据增强
图3 实矩形是图像,虚线矩形是随机作物。通过随机裁剪图像,我们抽样对比预测任务,包括全局到局部视图(B ->A)或邻近视图(D ->C)预测。
我们展示了这种复杂性可以通过对目标图像进行简单的随机裁剪(调整大小)来避免,这将创建包含上述两项的一系列预测任务,如图3所示。这种简单的设计选择方便地将预测任务与神经网络体系结构等其他组件解耦。广义的对比预测任务可以通过扩展增广函数族并将其随机组合来定义。