溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Python?Pytorch圖像檢索實(shí)例分析

發(fā)布時(shí)間:2022-04-08 16:50:23 來(lái)源:億速云 閱讀:301 作者:iii 欄目:開發(fā)技術(shù)

這篇文章主要介紹“Python Pytorch圖像檢索實(shí)例分析”,在日常操作中,相信很多人在Python Pytorch圖像檢索實(shí)例分析問題上存在疑惑,小編查閱了各式資料,整理出簡(jiǎn)單好用的操作方法,希望對(duì)大家解答”Python Pytorch圖像檢索實(shí)例分析”的疑惑有所幫助!接下來(lái),請(qǐng)跟著小編一起來(lái)學(xué)習(xí)吧!

背景

圖像檢索的基本本質(zhì)是根據(jù)查詢圖像的特征從集合或數(shù)據(jù)庫(kù)中查找圖像。

大多數(shù)情況下,這種特征是圖像之間簡(jiǎn)單的視覺相似性。在一個(gè)復(fù)雜的問題中,這種特征可能是兩幅圖像在風(fēng)格上的相似性,甚至是互補(bǔ)性。

由于原始形式的圖像不會(huì)在基于像素的數(shù)據(jù)中反映這些特征,因此我們需要將這些像素?cái)?shù)據(jù)轉(zhuǎn)換為一個(gè)潛空間,在該空間中,圖像的表示將反映這些特征。

一般來(lái)說,在潛空間中,任何兩個(gè)相似的圖像都會(huì)相互靠近,而不同的圖像則會(huì)相隔很遠(yuǎn)。這是我們用來(lái)訓(xùn)練我們的模型的基本管理規(guī)則。一旦我們這樣做,檢索部分只需搜索潛在空間,在給定查詢圖像表示的潛在空間中拾取最近的圖像。大多數(shù)情況下,它是在最近鄰搜索的幫助下完成的。

因此,我們可以將我們的方法分為兩部分:

  • 圖像表現(xiàn)

  • 搜索

我們將在Oxford 102 Flowers數(shù)據(jù)集上解決這兩個(gè)部分。

圖像表現(xiàn)

我們將使用一種叫做暹羅模型的東西,它本身并不是一種全新的模型,而是一種訓(xùn)練模型的技術(shù)。大多數(shù)情況下,這是與triplet loss一起使用的。這個(gè)技術(shù)的基本組成部分是三元組。

三元組是3個(gè)獨(dú)立的數(shù)據(jù)樣本,比如A(錨點(diǎn)),B(陽(yáng)性)和C(陰性);其中A和B相似或具有相似的特征(可能是同一類),而C與A和B都不相似。這三個(gè)樣本共同構(gòu)成了訓(xùn)練數(shù)據(jù)的一個(gè)單元——三元組。

注:任何圖像檢索任務(wù)的90%都體現(xiàn)在暹羅網(wǎng)絡(luò)、triplet loss和三元組的創(chuàng)建中。如果你成功地完成了這些,那么整個(gè)努力的成功或多或少是有保證的。

首先,我們將創(chuàng)建管道的這個(gè)組件——數(shù)據(jù)。下面我們將在PyTorch中創(chuàng)建一個(gè)自定義數(shù)據(jù)集和數(shù)據(jù)加載器,它將從數(shù)據(jù)集中生成三元組。

class TripletData(Dataset):
    def __init__(self, path, transforms, split="train"):
 
        self.path = path
        self.split = split    # train or valid
        self.cats = 102       # number of categories
        self.transforms = transforms
 
        
    def __getitem__(self, idx):
 
        # our positive class for the triplet
        idx = str(idx%self.cats + 1)
 
        # choosing our pair of positive images (im1, im2)
        positives = os.listdir(os.path.join(self.path, idx))
        im1, im2 = random.sample(positives, 2)
 
        # choosing a negative class and negative image (im3)
        negative_cats = [str(x+1) for x in range(self.cats)]
        negative_cats.remove(idx)
        negative_cat = str(random.choice(negative_cats))
        negatives = os.listdir(os.path.join(self.path, negative_cat))
 
        im3 = random.choice(negatives)
 
        im1,im2,im3 = os.path.join(self.path, idx, im1), os.path.join(self.path, idx, im2), os.path.join(self.path, negative_cat, im3)
 
        im1 = self.transforms(Image.open(im1))
 
        im2 = self.transforms(Image.open(im2))
 
        im3 = self.transforms(Image.open(im3))
 
        return [im1, im2, im3]
 
    
    # we'll put some value that we want since there can be far too many triplets possible
    # multiples of the number of images/ number of categories is a good choice
    def __len__(self):
        return self.cats*8
# Transforms
train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Datasets and Dataloaders
train_data = TripletData(PATH_TRAIN, train_transforms)
val_data = TripletData(PATH_VALID, val_transforms)
train_loader = torch.utils.data.DataLoader(dataset = train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(dataset = val_data, batch_size=32, shuffle=False, num_workers=4)

現(xiàn)在我們有了數(shù)據(jù),讓我們轉(zhuǎn)到暹羅網(wǎng)絡(luò)。

暹羅網(wǎng)絡(luò)給人的印象是2個(gè)或3個(gè)模型,但是它本身是一個(gè)單一的模型。所有這些模型共享權(quán)重,即只有一個(gè)模型。

Python?Pytorch圖像檢索實(shí)例分析

如前所述,將整個(gè)體系結(jié)構(gòu)結(jié)合在一起的關(guān)鍵因素是triplet loss。triplet loss產(chǎn)生了一個(gè)目標(biāo)函數(shù),該函數(shù)迫使相似輸入對(duì)(錨點(diǎn)和正)之間的距離小于不同輸入對(duì)(錨點(diǎn)和負(fù))之間的距離,并限定一定的閾值。

下面我們來(lái)看看triplet loss以及訓(xùn)練管道實(shí)現(xiàn)。

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
# Training
for epoch in range(epochs):
    
    model.train()
    epoch_loss = 0.0
    
    for data in tqdm(train_loader):
        
        optimizer.zero_grad()
        x1,x2,x3 = data
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))
 
    
    
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        
        super(TripletLoss, self).__init__()
        self.margin = margin
        
        
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    
    
    # Distances in embedding space is calculated in euclidean
    def forward(self, anchor, positive, negative):
        
        distance_positive = self.calc_euclidean(anchor, positive)
        
        distance_negative = self.calc_euclidean(anchor, negative)
        
        losses = torch.relu(distance_positive - distance_negative + self.margin)
        
        return losses.mean()
      
 
device = 'cuda'
 
 
# Our base model
model = models.resnet18().cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
triplet_loss = TripletLoss()
 
 
# Training
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for data in tqdm(train_loader):
 
        optimizer.zero_grad()
        
        x1,x2,x3 = data
        
        e1 = model(x1.to(device))
        e2 = model(x2.to(device))
        e3 = model(x3.to(device)) 
        
        loss = triplet_loss(e1,e2,e3)
        epoch_loss += loss
        loss.backward()
        optimizer.step()
        
    print("Train Loss: {}".format(epoch_loss.item()))

到目前為止,我們的模型已經(jīng)經(jīng)過訓(xùn)練,可以將圖像轉(zhuǎn)換為一個(gè)嵌入空間。接下來(lái),我們進(jìn)入搜索部分。

搜索

我們可以很容易地使用Scikit Learn提供的最近鄰搜索。我們將探索新的更好的東西,而不是走簡(jiǎn)單的路線。

我們將使用Faiss。這比最近的鄰居要快得多,如果我們有大量的圖像,這種速度上的差異會(huì)變得更加明顯。

下面我們將演示如何在給定查詢圖像時(shí),在存儲(chǔ)的圖像表示中搜索最近的圖像。

#!pip install faiss-gpu
import faiss                            
faiss_index = faiss.IndexFlatL2(1000)   # build the index
 
# storing the image representations
im_indices = []
 
with torch.no_grad():
    for f in glob.glob(os.path.join(PATH_TRAIN, '*/*')):
        
        im = Image.open(f)
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        preds = model(im)
        preds = np.array([preds[0].cpu().numpy()])
        faiss_index.add(preds) #add the representation to index
        im_indices.append(f)   #store the image name to find it later on
 
        
# Retrieval with a query image
with torch.no_grad():
    for f in os.listdir(PATH_TEST):
        
        # query/test image
        im = Image.open(os.path.join(PATH_TEST,f))
        im = im.resize((224,224))
        im = torch.tensor([val_transforms(im).numpy()]).cuda()
    
        test_embed = model(im).cpu().numpy()
        
        _, I = faiss_index.search(test_embed, 5)
        print("Retrieved Image: {}".format(im_indices[I[0][0]]))

這涵蓋了基于現(xiàn)代深度學(xué)習(xí)的圖像檢索,但不會(huì)使其變得太復(fù)雜。大多數(shù)檢索問題都可以通過這個(gè)基本管道解決。

到此,關(guān)于“Python Pytorch圖像檢索實(shí)例分析”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實(shí)踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識(shí),請(qǐng)繼續(xù)關(guān)注億速云網(wǎng)站,小編會(huì)繼續(xù)努力為大家?guī)?lái)更多實(shí)用的文章!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI