您好,登錄后才能下訂單哦!
這篇文章給大家介紹如何利用PyTorch中的Moco-V2減少計算約束,內(nèi)容非常詳細,感興趣的小伙伴們可以參考借鑒,希望對大家能有所幫助。
SimCLR論文(http://cse.iitkgp.ac.in/~arastogi/papers/simclr.pdf)解釋了這個框架如何從更大的模型和更大的批處理中獲益,并且如果有足夠的計算能力,可以產(chǎn)生與監(jiān)督模型類似的結果。
但是這些需求使得框架的計算量相當大。如果我們可以擁有這個框架的簡單性和強大功能,并且有更少的計算需求,這樣每個人都可以訪問它,這不是很好嗎?Moco-v2前來救援。
這次我們將在Pytorch中在更大的數(shù)據(jù)集上實現(xiàn)Moco-v2,并在Google Colab上訓練我們的模型。這次我們將使用Imagenette和Imagewoof數(shù)據(jù)集
來自Imagenette數(shù)據(jù)集的一些圖像
這些數(shù)據(jù)集的快速摘要(更多信息在這里:https://github.com/fastai/imagenette):
Imagenette由Imagenet的10個容易分類的類組成,總共有9479個訓練圖像和3935個驗證集圖像。
Imagewoof是一個由Imagenet提供的10個難分類組成的數(shù)據(jù)集,因為所有的類都是狗的品種??偣灿?035個訓練圖像,3939個驗證集圖像。
對比學習在自我監(jiān)督學習中的作用是基于這樣一個理念:我們希望同一類別中不同的圖像觀具有相似的表征。但是,由于我們不知道哪些圖像屬于同一類別,通常所做的是將同一圖像的不同外觀的表示拉近。我們把這些不同的外觀稱為正對(positive pairs)。
另外,我們希望不同類別的圖像有不同的外觀,使它們的表征彼此遠離。不同圖像的不同外觀的呈現(xiàn)與類別無關,會被彼此推開。我們把這些不同的外觀稱為負對(negative pairs)。
在這種情況下,一個圖像的前景是什么?前景可以被認為是以一種經(jīng)過修改的方式看待圖像的某些部分,它本質(zhì)上是圖像的一種變換。
根據(jù)手頭的任務,有些轉換可以比其他轉換工作得更好。SimCLR表明,應用隨機裁剪和顏色抖動可以很好地完成各種任務,包括圖像分類。這本質(zhì)上來自于網(wǎng)格搜索,從旋轉、裁剪、剪切、噪聲、模糊、Sobel濾波等選項中選擇一對變換。
從外觀到表示空間的映射是通過神經(jīng)網(wǎng)絡完成的,通常,resnet用于此目的。下面是從圖像到表示的管道
在同一幅圖像中,由于隨機裁剪,我們可以得到多個表示。這樣,我們就可以產(chǎn)生正對。
但是如何生成負對呢?負對是來自不同圖像的表示。SimCLR論文在同一批中創(chuàng)建了這些。如果一個批包含N個圖像,那么對于每個圖像,我們將得到2個表示,這總共占2*N個表示。對于一個特定的表示x,有一個表示與x形成正對(與x來自同一個圖像的表示),其余所有表示(正好是2*N–2)與x形成負對。
如果我們手頭有大量的負樣本,這些表示就會得到改善。但是,在SimCLR中,只有當批量較大時,才能實現(xiàn)大量的負樣本,這導致了對計算能力的更高要求。MoCo-v2提供了生成負樣本的另一種方法。讓我們詳細了解一下。
我們可以用一種稍微不同的方式來看待對比學習方法,即將查詢與鍵進行匹配。我們現(xiàn)在有兩個編碼器,一個用于查詢,另一個用于鍵。此外,為了得到大量的負樣本,我們需要一個大的鍵編碼字典。
此上下文中的正對表示查詢與鍵匹配。如果查詢和鍵都來自同一個圖像,則它們匹配。編碼的查詢應該與其匹配的鍵相似,而與其他查詢不同。
對于負對,我們維護一個大字典,其中包含以前批處理的編碼鍵。它們作為查詢的負樣本。我們以隊列的形式維護字典。新的batch被入隊,較早的batch被出列。通過更改此隊列的大小,可以更改負采樣數(shù)。
隨著鍵編碼器的更改,在稍后時間點排隊的鍵可能與較早排隊的鍵不一致。為了使用對比學習方法,與查詢進行比較的所有鍵必須來自相同或相似的編碼器,這樣比較才會有意義且一致。
另一個挑戰(zhàn)是,使用反向傳播學習編碼器參數(shù)是不可行的,因為這將需要計算隊列中所有樣本的梯度(這將導致大的計算圖)。
為了解決這兩個問題,MoCo將鍵編碼器實現(xiàn)為基于動量的查詢編碼器的移動平均值[1]。這意味著它以這種方式更新關鍵編碼器參數(shù):
其中m非常接近于1(例如,典型值為0.999),這確保我們在不同的時間從相似的編碼器獲得編碼鍵。
我們希望查詢接近其所有正樣本,遠離所有負樣本。InfoNC函數(shù)E會捕獲它。它代表信息噪聲對比估計。對于查詢q和鍵k,InfoNCE損失函數(shù)是:
我們可以重寫為:
當q和k的相似性增大,q與負樣本的相似性減小時,損失值減小
以下是損失函數(shù)的代碼:
τ = 0.05 def loss_function(q, k, queue): # N是批量大小 N = q.shape[0] # C是表示的維數(shù) C = q.shape[1] # bmm代表批處理矩陣乘法 # 如果mat1是b×n×m張量,那么mat2是b×m×p張量, # 然后輸出一個b×n×p張量。 pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ)) # 在查詢和隊列張量之間執(zhí)行矩陣乘法 neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1) # 求和 denominator = neg + pos return torch.mean(-torch.log(torch.div(pos,denominator)))
讓我們再看看這個損失函數(shù),并將它與分類交叉熵損失函數(shù)進行比較。
這里pred?是數(shù)據(jù)點在第i類中的概率值預測,true?是該點屬于第i類的實際概率值(可以是模糊的,但大多數(shù)情況下是一個one-hot)。
如果你不熟悉這個話題,你可以看這個視頻來更好地理解交叉熵。另外,請注意,我們經(jīng)常通過softmax這樣的函數(shù)將分數(shù)轉換為概率值:https://www.youtube.com/watch?v=ErfnhcEV1O8
我們可以把信息損失函數(shù)看作交叉熵損失。數(shù)據(jù)樣本“q”的正確類是第r類,底層分類器基于softmax,它試圖在K+1類之間進行分類。
Info-NCE還與編碼表示之間的相互信息有關;關于這一點的更多細節(jié)見[4]。
現(xiàn)在,讓我們把所有的東西放在一起,看看整個Moco-v2算法是什么樣子的。
我們必須得到查詢和鍵編碼器。最初,鍵編碼器具有與查詢編碼器相同的參數(shù)。它們是彼此的復制品。隨著訓練的進行,鍵編碼器將成為查詢編碼器的移動平均值(在這一點上進展緩慢)。
由于計算能力的限制,我們使用Resnet-18體系結構來實現(xiàn)。在通常的resnet架構之上,我們添加了一些密集的層,以使表示的維數(shù)降到25。這些層中的某些層稍后將充當投影。
# 定義我們的深度學習架構 resnetq = resnet18(pretrained=False) classifier = nn.Sequential(OrderedDict([ ('fc1', nn.Linear(resnetq.fc.in_features, 100)), ('added_relu1', nn.ReLU(inplace=True)), ('fc2', nn.Linear(100, 50)), ('added_relu2', nn.ReLU(inplace=True)), ('fc3', nn.Linear(50, 25)) ])) resnetq.fc = classifier resnetk = copy.deepcopy(resnetq) # 將resnet架構遷移到設備 resnetq.to(device) resnetk.to(device)
現(xiàn)在,我們已經(jīng)有了編碼器,并且假設我們已經(jīng)設置了其他重要的數(shù)據(jù)結構,現(xiàn)在是時候開始訓練循環(huán)并理解管道了。
這一步是從訓練批中獲取編碼查詢和鍵。我們用L2范數(shù)對表示進行規(guī)范化。
只是一個約定警告,所有后續(xù)步驟中的代碼都將位于批處理和epoch循環(huán)中。我們還將張量“k”從它的梯度中分離出來,因為我們不需要計算圖中的鍵編碼器部分,因為動量更新方程會更新鍵編碼器。
# 梯度零化 optimizer.zero_grad() # 檢索xq和xk這兩個圖像batch xq = sample_batched['image1'] xk = sample_batched['image2'] # 把它們移到設備上 xq = xq.to(device) xk = xk.to(device) # 獲取他們的輸出 q = resnetq(xq) k = resnetk(xk) k = k.detach() # 將輸出規(guī)范化,使它們成為單位向量 q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1)) k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))
現(xiàn)在,我們將查詢、鍵和隊列傳遞給前面定義的loss函數(shù),并將值存儲在一個列表中。然后,像往常一樣,對損失值調(diào)用backward函數(shù)并運行優(yōu)化器。
# 獲得損失值 loss = loss_function(q, k, queue) # 把這個損失值放到epoch損失列表中 epoch_losses_train.append(loss.cpu().data.item()) # 反向傳播 loss.backward() # 運行優(yōu)化器 optimizer.step()
我們將最新的batch加入我們的隊列。如果我們的隊列大小大于我們定義的最大隊列大小(K),那么我們就從其中取出最老的batch??梢允褂胻orch.cat進行隊列操作。
# 更新隊列 queue = torch.cat((queue, k), 0) # 如果隊列大于最大隊列大?。╧),則出列 # batch大小是256,可以用變量替換 if queue.shape[0] > K: queue = queue[256:,:]
現(xiàn)在我們進入訓練循環(huán)的最后一步,即更新鍵編碼器。我們使用下面的for循環(huán)來實現(xiàn)這一點。
# 更新resnet for θ_k, θ_q in zip(resnetk.parameters(), resnetq.parameters()): θ_k.data.copy_(momentum*θ_k.data + θ_q.data*(1.0 - momentum))
訓練resnet-18模型的Imagenette和Imagewoof數(shù)據(jù)集的GPU時間接近18小時。為此,我們使用了googlecolab的GPU(16GB)。我們使用的batch大小為256,tau值為0.05,學習率為0.001,最終降低到1e-5,權重衰減為1e-6。我們的隊列大小為8192,鍵編碼器的動量值為0.999。
前3層(將relu視為一層)定義了投影頭,我們將其移除用于圖像分類的下游任務。在剩下的網(wǎng)絡上,我們訓練了一個線性分類器。
我們得到了64.2%的正確率,而使用10%的標記訓練數(shù)據(jù),使用MoCo-v2。相比之下,使用最先進的監(jiān)督學習方法,其準確率接近95%。
對于Imagewoof,我們對10%的標記數(shù)據(jù)得到了38.6%的準確率。在這個數(shù)據(jù)集上進行對比學習的效果低于我們的預期。我們懷疑這是因為首先,數(shù)據(jù)集非常困難,因為所有類都是狗類。
其次,我們認為顏色是這些類的一個重要的區(qū)別特征。應用顏色抖動可能會導致來自不同類的多個圖像彼此混合表示。相比之下,監(jiān)督方法的準確率接近90%。
能夠彌合自監(jiān)督模型和監(jiān)督模型之間差距的設計變更:
使用更大更寬的模型。
通過使用更大的批量和字典大小。
使用更多的數(shù)據(jù),如果可以的話。同時引入所有未標記的數(shù)據(jù)。
在大量數(shù)據(jù)上訓練大型模型,然后提取它們。
關于如何利用PyTorch中的Moco-V2減少計算約束就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,可以學到更多知識。如果覺得文章不錯,可以把它分享出去讓更多的人看到。
免責聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權內(nèi)容。