Contrastive Leanring
对比学习 (Contrastive
Learning)是机器学习领域的一种自监督学习方法,其核心思想是通过让模型区分相似(正样本)与不相似(负样本)的数据,学习数据的高效特征表示 。它不需要人工进行标注数据的标签,而是利用数据本身的潜在结构来构建监督信号,可以通俗地理解成让模型学习"什么像什么,什么不像什么" 。
代理任务(pretext
task) 定义如何生成和利用正负样本对,通过模型对这些样本的对比学习,最终提取出通用的特征表示,能够将无标签数据转化为"伪监督信号",代替人工标注。而不同的代理任务会通过不同的规则构造正负样本对,常见的代理任务及其正负样本构造方式如下:
pretext task
正样本构造方式
负样本构造方式
实例判别(instance discrimination)
同一图片的不同增强版本(裁剪、旋转等)
不同图片的增强版本
拼图还原(Jigsaw Puzzle)
同一图片的不同拼图块
不同图片的拼图块
时序预测(Temporal Prediction)
同一视频的相邻帧
不同视频的随机帧
掩码预测(Masked Prediction)
同一文本的不同掩码版本
不同文本的掩码版本
对比学习的目标是在特征空间中,拉近正样本对,推开负样本对 ,使相似样本的嵌入(embedding)距离更近,不相似样本的嵌入距离更远,这里的相似度通常用欧几里得距离 或余弦相似度 来度量。
对比学习通常采用InfoNCE(Noise Contrastive
Estimation) 损失函数:
\[
\mathcal{L}=-\text{log}\frac{\text{exp}(\text{sim}(q,k_+)/\tau)}{\sum_{i=0}^{K}\text{exp}(\text{sim}(q,k_i)/\tau)}
\]
\(q\) : 查询样本的特征
\(k_+\) : 正样本的特征
\(k_i\) : 负样本的特征
\(\tau\) :
温度参数,控制分布的尖锐程度
\(\text{sim}(\cdot)\) :
相似度函数
直观来看,这不就是常见的CrossEntropyLoss吗。没错,但是在对比学习领域,这个损失函数会有很大的不同。首先,在有监督学习中,K通常指训练样本的类别数,但是在对比学习中,如果使用instance
discrimination作为代理任务的话,那K将会是非常大的数字,即除该样本外其余所有样本的总和,因此,如果直接使用交叉熵损失函数,训练的时间将会非常的漫长。而NCE(Noise
Contrastive Estimation)将所有样本分为两类,data sample和noise
sample,每次训练时就将这两种样本进行对比即可。但是仅仅只做这一个改变还不行,无法解决单次训练样本数过多的问题,NCE每次仅从负样本中取出小部分的样本去做估计,这就是Estimation的含义。而InfoNCE是NCE一个变体,将二分类问题转换为多分类问题。
人脸验证(Face Verification)
人脸验证通常被另一个常见的计算机视觉任务混淆,即人脸识别。这两者密切相关,但目标不同
任务类型
人脸识别(Face Recognition)
人脸验证(Face Verification)
核心目标
一对多比对 :
判断输入人脸属于已知数据库中的哪一个个体
一对一比对 : 判断两张人脸是否属于同一个人
输出形式
输出个体身份ID
输出二元判定结果:“是同一个人”或“不是同一个人”
技术本质
多分类问题
二分类问题(相似性阈值判定)
MoCo(Momentum Contrast)
Momentum Contrast for Unsupervised Visual Representation
Learning Kaiming He et al. | arXiv 1911.05722 | Code | CVPR 2020 |
Facebook AI Research (FAIR)
MoCo使用一个动态的负样本队列,存储历史批次的特征作为负样本,可以显著增加负样本数量而无需增大batch_size。并且,MoCo将动量的思想引入编码器中,由于负样本队列中的特征来自不同时刻的编码器,参数不一致会导致噪声的干扰,而动量的引入使得负责编码负样本的编码器的参数只会缓慢随着查询编码器进行更新,能保证负样本特征的一致性
MoCo前的模型
在MoCo前有两种经典的对比学习模型:SimCLR, Memory Bank
SimCLR arXiv:2002.05709
SimCLR是一个End to
End模型,其负样本仅使用当前批次的样本作为负样本,并且对query
encoder和key
encoder都进行反向传播。这种方法虽然特征的一致性高,但是为了保证负样本特征的数量足够大,batch_size会特别大,对显存要求极高
Memory Bank arXiv:1805.01978v1
MemoryBank通过维护一个外部存储库即Memory
Bank,保存历史样本的特征:
每个样本的特征被缓存在Memory Bank中
更新时,使用旧模型参数生成新特征并替换旧值
虽然该方法支持海量的负样本,但是由于每次仅更替其中的部分负样本特征,因此特征一致性较差(旧参数生成的特征可能与新参数生成的特征在时间上差距很大)
MoCo的改变
MoCo结合的上述两种模型的思想,通过两种机制解决支持大量负样本和特征一致性问题
MoCo引入动态队列存储动量编码器生成的历史特征,通过队列的出队与入队保证了可以支持海量的负样本同时只需要较小的batch
size
引入动量机制,query encoder可以通过反向传播更新参数,而key
encoder即momentum encoder需要通过动量机制更新参数,使得key
encoder的参数只随query encoder缓慢更新,保证负样本的一致性
\[
\theta_k \leftarrow m\theta_k + (1-m)\theta_q
\]
其中:
\(\theta_k\) 是key
encoder的参数
\(\theta_q\) 是query
encoder的参数
m是动量系数,m越接近则key encoder参数随着query
encoder参数的变化越慢,原论文中作者将其设置为一个很大的数值,0.999
Loss计算上,MoCo使用InfoNCE作为损失函数
原论文给出了模型的伪代码,更方便我们理解MoCo模型
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 f_k.params = f_q.params for x in loader: x_q = augmentation(x) x_k = augmentation(x) q = f_q(x_q) k = f_k(x_k).detach() l_pos = torch.bmm(q.view(N, 1 , C), k.view(N, C, 1 )) l_neg = torch.mm(q.view(N, C), queue.view(C, K)) logits = torch.cat([l_pos, l_neg], dim=1 ) labels = torch.zeros(N, dtype=torch.long).cuda() loss = CrossEntropyLoss(logits / t, labels) loss.backward() update(f_q.params) f_k.params = m * f_k.params + (1 -m) * f_q.params enqueue(queue, k) dequeue(queue)
实战部分
环境配置
PyTorch 2.1.2 Python 3.10 以及对应版本的torchvision
ubuntu22.04
GPU: 单块RTX4090(24GB)
需要的库
1 2 3 4 5 6 7 8 9 10 11 12 13 14 import numpy as npimport randomimport timeimport shutilimport matplotlib.pyplot as pltimport torchfrom torch import nnfrom torch.nn import functional as Ffrom torch.utils.data import DataLoaderimport torchvision.transforms as transformsimport torchvision.datasets as datasetsfrom torchvision import modelsimport osimport glob
数据预处理
常见的人脸验证数据集有 CASIA-WebFace、Celeb等
这里采用从kaggle 上下载的CASIA-WebFace数据集进行训练,并经过筛选,最终数据集大小为211958张图片。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 root_dir = "./data/data/casia-webface/" image_root_name = [num for num in os.listdir(root_dir) if num.isdigit()] image_root_name.sort() image_dir = [root_dir + num + "/" for num in image_root_name] def show_images (image_dir ): images = glob.glob(image_dir + "*.jpeg" ) fig, axes = plt.subplots(1 , len (images), figsize=(10 , 5 )) for i, image in enumerate (images): axes[i].imshow(plt.imread(image)) axes[i].axis("off" ) plt.show() for i in range (5 ): show_images(image_dir[i])
参数设置
1 2 3 4 5 6 7 8 9 10 11 seed = 2025 random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) batch_size = 256 lr = 0.03 momentum = 0.9 weight_decay = 1e-4 epochs = 200 print_freq = 10
数据增强
MoCo结合了多种数据增强策略,以生成多样化的正样本对。MoCo
v2在SimCLR的基础上优化了数据增强组合
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 import randomfrom PIL import ImageFilterclass TwoCropsTransform : def __init__ (self, base_transform ): ''' 将一张图片的两个随机crop作为query和key ''' self.base_transform = base_transform def __call__ (self, x ): q = self.base_transform(x) k = self.base_transform(x) return q, k class GaussianBlur : ''' Gaussian blur augmentation: https://arxiv.org/abs/2002.05709 ''' def __init__ (self, sigma=[0.1 , 2.0 ] ): self.sigma = sigma def __call__ (self, x ): sigma = random.uniform(self.sigma[0 ], self.sigma[1 ]) x = x.filter (ImageFilter.GaussianBlur(radius=sigma)) return x
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 import moco.loadernormalize = transforms.Normalize( mean=[0.485 , 0.456 , 0.406 ], std=[0.229 , 0.224 , 0.225 ] ) augmentation = transforms.Compose([ transforms.RandomResizedCrop(224 , scale=(0.2 , 1.0 )), transforms.RandomApply( [transforms.ColorJitter(0.4 , 0.4 , 0.4 , 0.1 )], p=0.8 ), transforms.RandomGrayscale(p=0.2 ), transforms.RandomApply([moco.loader.GaussianBlur([0.1 , 2.0 ])], p=0.5 ), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize ])
生成训练数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 train_dir = root_dir train_dataset = datasets.ImageFolder( train_dir, moco.loader.TwoCropsTransform(augmentation) ) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True , num_workers=4 , pin_memory=True , drop_last=True )
MoCo模型框架
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 import torchimport torch.distributedimport torch.nn as nnfrom torch.nn import functional as Fclass MoCo (nn.Module): ''' MoCo: https://arxiv.org/abs/1911.05722 ''' def __init__ (self, base_encoder, dim=128 , K=65536 , m=0.999 , T=0.07 , mlp=False ): ''' @param: base_encoder: base encoder network dim: 特征维度, feature_dimension K: queue size, 负样本的数量 m: momentum for updating key encoder T: temperature parameter mlp: 是否使用mlp ''' super (MoCo, self).__init__() self.K = K self.m = m self.T = T self.encoder_q = base_encoder(num_classes=dim) self.encoder_k = base_encoder(num_classes=dim) if mlp: dim_mlp = self.encoder_q.fc.weight.shape[1 ] self.encoder_q.fc = nn.Sequential( nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc ) self.encoder_k.fc = nn.Sequential( nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc ) for param_q, param_k in zip (self.encoder_q.parameters(), self.encoder_k.parameters()): param_k.data.copy_(param_q.data) param_k.requires_grad = False self.register_buffer("queue" , torch.randn(dim, K)) self.queue = nn.functional.normalize(self.queue, dim=0 ) self.register_buffer("queue_ptr" , torch.zeros(1 , dtype=torch.long)) @torch.no_grad() def _momentum_update_key_encoder (self ): ''' 更新key encoder ''' for param_q, param_k in zip (self.encoder_q.parameters(), self.encoder_k.parameters()): param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) @torch.no_grad() def _dequeue_and_enqueue (self, keys ): ''' 出队入队 ''' keys = concat_all_gather(keys) batch_size = keys.shape[0 ] ptr = int (self.queue_ptr) assert self.K % batch_size == 0 self.queue[:, ptr:ptr + batch_size] = keys.T ptr = (ptr + batch_size) % self.K self.queue_ptr[0 ] = ptr @torch.no_grad() def _batch_shuffle_single_gpu (self, x ): ''' 打乱batch ''' idx_shuffle = torch.randperm(x.shape[0 ]) idx_unshuffle = torch.empty_like(idx_shuffle, device=x.device) idx_unshuffle[idx_shuffle] = torch.arange(x.shape[0 ], device=x.device) x = x[idx_shuffle] return x, idx_unshuffle @torch.no_grad() def _batch_unshuffle_single_gpu (self, x, idx_unshuffle ): ''' 恢复batch ''' return x[idx_unshuffle] def forward (self, im_q, im_k ): ''' Input: im_q: a batch of query images im_k: a batch of key images Output: logits, targets ''' q = self.encoder_q(im_q) q = F.normalize(q, dim=1 ) with torch.no_grad(): self._momentum_update_key_encoder() im_k, idx_unshuffle = self._batch_shuffle_single_gpu(im_k) k = self.encoder_k(im_k) k = F.normalize(k, dim=1 ) k = self._batch_unshuffle_single_gpu(k, idx_unshuffle) l_pos = torch.einsum('nc,nc->n' , [q, k]).unsqueeze(-1 ) l_neg = torch.einsum('nc,ck->nk' , [q, self.queue.clone().detach()]) logits = torch.cat([l_pos, l_neg], dim=1 ) logits /= self.T labels = torch.zeros(logits.shape[0 ], dtype=torch.long).cuda() self._dequeue_and_enqueue(k) return logits, labels @torch.no_grad() def concat_all_gather (tensor ): ''' 所有进程的tensor拼接 ''' return tensor
论文源码中,作者是运行在分布式多GPU环境下,这里修改为仅在单GPU下训练
模型以及优化器
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 def base_encoder (num_classes=128 ): model = models.resnet50(pretrained=True ) in_features = model.fc.in_features model.fc = nn.Linear(in_features, num_classes) return model model = MoCo( base_encoder, dim=128 , K=65536 , m=0.999 , T=0.07 , mlp=True ) model = model.cuda() criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD( model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay )
模型训练
模型训练过程中相关数据的记录
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 class AverageMeter : '''计算并记录均值和当前值''' def __init__ (self, name, fmt=":f" ): self.name = name self.fmt = fmt self.reset() def reset (self ): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update (self, val, n=1 ): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def __str__ (self ): fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" return fmtstr.format (**self.__dict__) class ProgressMeter : def __init__ (self, num_batches, meters, prefix="" ): self.batch_fmtstr = self._get_batch_fmtstr(num_batches) self.meters = meters self.prefix = prefix def display (self, batch ): entries = [self.prefix + self.batch_fmtstr.format (batch)] entries += [str (meter) for meter in self.meters] print ("\t" .join(entries)) def _get_batch_fmtstr (self, num_batches ): num_digits = len (str (num_batches // 1 )) fmt = "{:" + str (num_digits) + "d}" return "[" + fmt + "/" + fmt.format (num_batches) + "]"
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 def accuracy (output, target, topk=(1 , ) ): with torch.no_grad(): maxk = max (topk) batch_size = target.size(0 ) _, pred = output.topk(maxk, 1 , True , True ) pred = pred.t() correct = pred.eq(target.view(1 , -1 ).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1 ).float ().sum (0 , keepdim=True ) res.append(correct_k.mul_(100.0 / batch_size)) return res def adjust_learning_rate (optimizer, epoch, lr ): lr *= 0.5 * (1.0 + np.cos(np.pi * epoch / epochs)) for param_group in optimizer.param_groups: param_group["lr" ] = lr def save_checkpoint (state, is_best, filename="checkpoint.pth.tar" ): torch.save(state, filename) if is_best: shutil.copyfile(filename, "model_best.pth.tar" ) def train (train_loader, model, criterion, optimizer, epoch ): batch_time = AverageMeter("Time" , ":6.3f" ) data_time = AverageMeter("Data" , ":6.3f" ) losses = AverageMeter("Loss" , ":.4e" ) top1 = AverageMeter("Acc@1" , ":6.2f" ) top5 = AverageMeter("Acc@5" , ":6.2f" ) progress = ProgressMeter( len (train_loader), [batch_time, data_time, losses, top1, top5], prefix="Epoch: [{}]" .format (epoch) ) model.train() end = time.time() for i, (images, _) in enumerate (train_loader): data_time.update(time.time() - end) if torch.cuda.is_available(): images[0 ] = images[0 ].cuda(non_blocking=True ) images[1 ] = images[1 ].cuda(non_blocking=True ) output, target = model(im_q = images[0 ], im_k = images[1 ]) loss = criterion(output, target) acc1, acc5 = accuracy(output, target, topk=(1 , 5 )) losses.update(loss.item(), images[0 ].size(0 )) top1.update(acc1[0 ], images[0 ].size(0 )) top5.update(acc5[0 ], images[0 ].size(0 )) optimizer.zero_grad() loss.backward() optimizer.step() batch_time.update(time.time() - end) end = time.time() if i % print_freq == 0 : progress.display(i) for epoch in range (epochs): adjust_learning_rate(optimizer, epoch, lr) train(train_loader, model, criterion, optimizer, epoch) save_checkpoint( { "epoch" : epoch + 1 , "base_encoder" : "resnet50" , "state_dict" : model.state_dict(), "optimizer" : optimizer.state_dict() }, is_best=False , filename="checkpoint_{:04d}.pth.tar" .format (epoch) )