Contrastive Leanring

对比学习(Contrastive Learning)是机器学习领域的一种自监督学习方法,其核心思想是通过让模型区分相似(正样本)与不相似(负样本)的数据,学习数据的高效特征表示。它不需要人工进行标注数据的标签,而是利用数据本身的潜在结构来构建监督信号,可以通俗地理解成让模型学习"什么像什么,什么不像什么"

代理任务(pretext task)定义如何生成和利用正负样本对,通过模型对这些样本的对比学习,最终提取出通用的特征表示,能够将无标签数据转化为"伪监督信号",代替人工标注。而不同的代理任务会通过不同的规则构造正负样本对,常见的代理任务及其正负样本构造方式如下:

pretext task 正样本构造方式 负样本构造方式
实例判别(instance discrimination) 同一图片的不同增强版本(裁剪、旋转等) 不同图片的增强版本
拼图还原(Jigsaw Puzzle) 同一图片的不同拼图块 不同图片的拼图块
时序预测(Temporal Prediction) 同一视频的相邻帧 不同视频的随机帧
掩码预测(Masked Prediction) 同一文本的不同掩码版本 不同文本的掩码版本

对比学习的目标是在特征空间中,拉近正样本对,推开负样本对,使相似样本的嵌入(embedding)距离更近,不相似样本的嵌入距离更远,这里的相似度通常用欧几里得距离余弦相似度来度量。

对比学习通常采用InfoNCE(Noise Contrastive Estimation)损失函数:

\[ \mathcal{L}=-\text{log}\frac{\text{exp}(\text{sim}(q,k_+)/\tau)}{\sum_{i=0}^{K}\text{exp}(\text{sim}(q,k_i)/\tau)} \]

  • \(q\): 查询样本的特征
  • \(k_+\): 正样本的特征
  • \(k_i\): 负样本的特征
  • \(\tau\): 温度参数,控制分布的尖锐程度
  • \(\text{sim}(\cdot)\): 相似度函数

直观来看,这不就是常见的CrossEntropyLoss吗。没错,但是在对比学习领域,这个损失函数会有很大的不同。首先,在有监督学习中,K通常指训练样本的类别数,但是在对比学习中,如果使用instance discrimination作为代理任务的话,那K将会是非常大的数字,即除该样本外其余所有样本的总和,因此,如果直接使用交叉熵损失函数,训练的时间将会非常的漫长。而NCE(Noise Contrastive Estimation)将所有样本分为两类,data sample和noise sample,每次训练时就将这两种样本进行对比即可。但是仅仅只做这一个改变还不行,无法解决单次训练样本数过多的问题,NCE每次仅从负样本中取出小部分的样本去做估计,这就是Estimation的含义。而InfoNCE是NCE一个变体,将二分类问题转换为多分类问题。

人脸验证(Face Verification)

人脸验证通常被另一个常见的计算机视觉任务混淆,即人脸识别。这两者密切相关,但目标不同

任务类型 人脸识别(Face Recognition) 人脸验证(Face Verification)
核心目标 一对多比对: 判断输入人脸属于已知数据库中的哪一个个体 一对一比对: 判断两张人脸是否属于同一个人
输出形式 输出个体身份ID 输出二元判定结果:“是同一个人”或“不是同一个人”
技术本质 多分类问题 二分类问题(相似性阈值判定)

MoCo(Momentum Contrast)

Momentum Contrast for Unsupervised Visual Representation Learning Kaiming He et al. | arXiv 1911.05722 | Code | CVPR 2020 | Facebook AI Research (FAIR)

MoCo使用一个动态的负样本队列,存储历史批次的特征作为负样本,可以显著增加负样本数量而无需增大batch_size。并且,MoCo将动量的思想引入编码器中,由于负样本队列中的特征来自不同时刻的编码器,参数不一致会导致噪声的干扰,而动量的引入使得负责编码负样本的编码器的参数只会缓慢随着查询编码器进行更新,能保证负样本特征的一致性

MoCo前的模型

在MoCo前有两种经典的对比学习模型:SimCLR, Memory Bank

SimCLR arXiv:2002.05709

SimCLR是一个End to End模型,其负样本仅使用当前批次的样本作为负样本,并且对query encoder和key encoder都进行反向传播。这种方法虽然特征的一致性高,但是为了保证负样本特征的数量足够大,batch_size会特别大,对显存要求极高

Memory Bank arXiv:1805.01978v1

MemoryBank通过维护一个外部存储库即Memory Bank,保存历史样本的特征:

  • 每个样本的特征被缓存在Memory Bank中
  • 更新时,使用旧模型参数生成新特征并替换旧值

虽然该方法支持海量的负样本,但是由于每次仅更替其中的部分负样本特征,因此特征一致性较差(旧参数生成的特征可能与新参数生成的特征在时间上差距很大)

MoCo的改变

MoCo结合的上述两种模型的思想,通过两种机制解决支持大量负样本和特征一致性问题

  • MoCo引入动态队列存储动量编码器生成的历史特征,通过队列的出队与入队保证了可以支持海量的负样本同时只需要较小的batch size
  • 引入动量机制,query encoder可以通过反向传播更新参数,而key encoder即momentum encoder需要通过动量机制更新参数,使得key encoder的参数只随query encoder缓慢更新,保证负样本的一致性

\[ \theta_k \leftarrow m\theta_k + (1-m)\theta_q \]

其中:

  • \(\theta_k\)是key encoder的参数
  • \(\theta_q\)是query encoder的参数
  • m是动量系数,m越接近则key encoder参数随着query encoder参数的变化越慢,原论文中作者将其设置为一个很大的数值,0.999

Loss计算上,MoCo使用InfoNCE作为损失函数

原论文给出了模型的伪代码,更方便我们理解MoCo模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# f_q, f_k: query encoder, key encoder
# queue: 负样本队列, CxK
# m: 动量系数
# t: 温度系数
f_k.params = f_q.params # 初始化,将query encoder的参数复制给key encoder
for x in loader:
# 对正负样本进行相同的数据增强
x_q = augmentation(x)
x_k = augmentation(x)

q = f_q(x_q) # (N, C)
k = f_k(x_k).detach() # (N, C) 并使用detach()阻止反向传播

l_pos = torch.bmm(q.view(N, 1, C), k.view(N, C, 1)) # 正样本对的logits (N, 1)
l_neg = torch.mm(q.view(N, C), queue.view(C, K)) # 负样本对的logits (N, K)

logits = torch.cat([l_pos, l_neg], dim=1) # 合并logits (N, K+1)
labels = torch.zeros(N, dtype=torch.long).cuda()

loss = CrossEntropyLoss(logits / t, labels)
loss.backward()
update(f_q.params) # 更新query encoder参数

f_k.params = m * f_k.params + (1-m) * f_q.params # 更新key encoder参数
# 入队与出队
enqueue(queue, k)
dequeue(queue)

实战部分

环境配置

  • PyTorch 2.1.2 Python 3.10 以及对应版本的torchvision
  • ubuntu22.04
  • GPU: 单块RTX4090(24GB)

需要的库

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np
import random
import time
import shutil
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision import models
import os
import glob

数据预处理

常见的人脸验证数据集有 CASIA-WebFace、Celeb等

这里采用从kaggle上下载的CASIA-WebFace数据集进行训练,并经过筛选,最终数据集大小为211958张图片。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# train.ipynb
root_dir = "./data/data/casia-webface/"
image_root_name = [num for num in os.listdir(root_dir) if num.isdigit()]
image_root_name.sort()
image_dir = [root_dir + num + "/" for num in image_root_name]

# 可视化数据
def show_images(image_dir):
images = glob.glob(image_dir + "*.jpeg")
fig, axes = plt.subplots(1, len(images), figsize=(10, 5))
for i, image in enumerate(images):
axes[i].imshow(plt.imread(image))
axes[i].axis("off")
plt.show()

for i in range(5):
show_images(image_dir[i])

参数设置

1
2
3
4
5
6
7
8
9
10
11
# train.ipynb
seed = 2025
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
batch_size = 256
lr = 0.03
momentum = 0.9
weight_decay = 1e-4
epochs = 200
print_freq = 10

数据增强

MoCo结合了多种数据增强策略,以生成多样化的正样本对。MoCo v2在SimCLR的基础上优化了数据增强组合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# ./moco/loader.py
import random
from PIL import ImageFilter

class TwoCropsTransform:
def __init__(self, base_transform):
'''
将一张图片的两个随机crop作为query和key
'''
self.base_transform = base_transform

def __call__(self, x):
q = self.base_transform(x)
k = self.base_transform(x)
return q, k

class GaussianBlur:
'''
Gaussian blur augmentation: https://arxiv.org/abs/2002.05709
'''
def __init__(self, sigma=[0.1, 2.0]):
self.sigma = sigma

def __call__(self, x):
sigma = random.uniform(self.sigma[0], self.sigma[1])
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
return x
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# train.ipynb
import moco.loader


normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
# moco v2 data augmentation
augmentation = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
transforms.RandomApply(
# brightness, contrast, saturation, hue
[transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8
),
transforms.RandomGrayscale(p=0.2),
transforms.RandomApply([moco.loader.GaussianBlur([0.1, 2.0])], p=0.5),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize
])

生成训练数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# train.ipynb
train_dir = root_dir
# datasets.ImageFolder: 用于加载图像数据集的通用数据加载器
# 适用于文件夹中每个类别一个文件夹的情况
# root/class1/image1.jpg
# root/class1/image2.jpg
# root/class2/image1.jpg
# ...
train_dataset = datasets.ImageFolder(
train_dir, moco.loader.TwoCropsTransform(augmentation)
)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
drop_last=True
)

MoCo模型框架

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# ./moco/model.py
import torch
import torch.distributed
import torch.nn as nn
from torch.nn import functional as F

class MoCo(nn.Module):
'''
MoCo: https://arxiv.org/abs/1911.05722
'''
def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
'''
@param:
base_encoder: base encoder network
dim: 特征维度, feature_dimension
K: queue size, 负样本的数量
m: momentum for updating key encoder
T: temperature parameter
mlp: 是否使用mlp
'''
super(MoCo, self).__init__()

self.K = K
self.m = m
self.T = T

# encoders, num_classes是fc的输出维度
self.encoder_q = base_encoder(num_classes=dim)
self.encoder_k = base_encoder(num_classes=dim)
if mlp:
dim_mlp = self.encoder_q.fc.weight.shape[1]
self.encoder_q.fc = nn.Sequential(
nn.Linear(dim_mlp, dim_mlp),
nn.ReLU(),
self.encoder_q.fc
)
self.encoder_k.fc = nn.Sequential(
nn.Linear(dim_mlp, dim_mlp),
nn.ReLU(),
self.encoder_k.fc
)

for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data.copy_(param_q.data)
param_k.requires_grad = False

# queue
self.register_buffer("queue", torch.randn(dim, K))
self.queue = nn.functional.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

@torch.no_grad()
def _momentum_update_key_encoder(self):
'''
更新key encoder
'''
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
'''
出队入队
'''
keys = concat_all_gather(keys)
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
assert self.K % batch_size == 0

# replace the keys at ptr (dequeue and enqueue)
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr

@torch.no_grad()
def _batch_shuffle_single_gpu(self, x):
'''
打乱batch
'''
# random shuffle
idx_shuffle = torch.randperm(x.shape[0])
# idx_unshuffle = torch.argsort(idx_shuffle)
idx_unshuffle = torch.empty_like(idx_shuffle, device=x.device)
idx_unshuffle[idx_shuffle] = torch.arange(x.shape[0], device=x.device)
x = x[idx_shuffle]
return x, idx_unshuffle

@torch.no_grad()
def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
'''
恢复batch
'''
return x[idx_unshuffle]

def forward(self, im_q, im_k):
'''
Input:
im_q: a batch of query images
im_k: a batch of key images
Output:
logits, targets
'''
q = self.encoder_q(im_q)
q = F.normalize(q, dim=1) # q: NxC

with torch.no_grad():
self._momentum_update_key_encoder()
im_k, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)
k = self.encoder_k(im_k)
k = F.normalize(k, dim=1)
k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)

# compute logits
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# l_pos = F.cosine_similarity(q, k, dim=1).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# l_neg = F.cosine_similarity(q.unsqueeze(1), self.queue.clone().detach().T.unsqueeze(0), dim=2)
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# apply temperature
logits /= self.T
# labels: positive key indicators
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
# dequeue and enqueue
self._dequeue_and_enqueue(k)
return logits, labels

@torch.no_grad()
def concat_all_gather(tensor):
'''
所有进程的tensor拼接
'''
return tensor

论文源码中,作者是运行在分布式多GPU环境下,这里修改为仅在单GPU下训练

模型以及优化器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# train.ipynb
def base_encoder(num_classes=128):
model = models.resnet50(pretrained=True)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

return model

model = MoCo(
base_encoder,
dim=128,
K=65536,
m=0.999,
T=0.07,
mlp=True
)
model = model.cuda()

criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(
model.parameters(),
lr=lr,
momentum=momentum,
weight_decay=weight_decay
)

模型训练

模型训练过程中相关数据的记录

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# ./manager.py
class AverageMeter:
'''计算并记录均值和当前值'''
def __init__(self, name, fmt=":f"):
self.name = name
self.fmt = fmt
self.reset()

def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0

def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count

def __str__(self):
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)

class ProgressMeter:
def __init__(self, num_batches, meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix

def display(self, batch):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
print("\t".join(entries))

def _get_batch_fmtstr(self, num_batches):
num_digits = len(str(num_batches // 1))
fmt = "{:" + str(num_digits) + "d}"
return "[" + fmt + "/" + fmt.format(num_batches) + "]"
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# train.ipynb
# 正样本对的相似度在topk的比例,反映模型对相似样本的容错能力
def accuracy(output, target, topk=(1,)):
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)

_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))

res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))

return res
# cosine annealing learning rate schedule
def adjust_learning_rate(optimizer, epoch, lr):
lr *= 0.5 * (1.0 + np.cos(np.pi * epoch / epochs))
for param_group in optimizer.param_groups:
param_group["lr"] = lr

def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
torch.save(state, filename)
if is_best:
shutil.copyfile(filename, "model_best.pth.tar")

def train(train_loader, model, criterion, optimizer, epoch):
batch_time = AverageMeter("Time", ":6.3f")
data_time = AverageMeter("Data", ":6.3f")
losses = AverageMeter("Loss", ":.4e")
top1 = AverageMeter("Acc@1", ":6.2f")
top5 = AverageMeter("Acc@5", ":6.2f")
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses, top1, top5],
prefix="Epoch: [{}]".format(epoch)
)

model.train()

end = time.time()
for i, (images, _) in enumerate(train_loader):
# 记录数据加载时间
data_time.update(time.time() - end)

if torch.cuda.is_available():
images[0] = images[0].cuda(non_blocking=True)
images[1] = images[1].cuda(non_blocking=True)

output, target = model(im_q = images[0], im_k = images[1])
loss = criterion(output, target)

acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), images[0].size(0))
top1.update(acc1[0], images[0].size(0))
top5.update(acc5[0], images[0].size(0))

optimizer.zero_grad()
loss.backward()
optimizer.step()

batch_time.update(time.time() - end)
end = time.time()

if i % print_freq == 0:
progress.display(i)

for epoch in range(epochs):
adjust_learning_rate(optimizer, epoch, lr)
train(train_loader, model, criterion, optimizer, epoch)
save_checkpoint(
{
"epoch": epoch + 1,
"base_encoder": "resnet50",
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict()
},
is_best=False,
filename="checkpoint_{:04d}.pth.tar".format(epoch)
)