溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

發(fā)布時(shí)間:2023-04-25 11:54:21 來源:億速云 閱讀:133 作者:zzz 欄目:開發(fā)技術(shù)

這篇文章主要介紹了Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果的相關(guān)知識(shí),內(nèi)容詳細(xì)易懂,操作簡(jiǎn)單快捷,具有一定借鑒價(jià)值,相信大家閱讀完這篇Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果文章都會(huì)有所收獲,下面我們一起來看看吧。

開發(fā)環(huán)境

集成開發(fā)工具:jupyter notebook 6.5.2
集成開發(fā)環(huán)境:Python 3.10.6
第三方庫:torch、matplotlib、sklearn、numpy

1 加載相關(guān)第三方庫

# 使得在 notebook 中顯示繪圖,而不是在外部窗口中顯示
%matplotlib inline  
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_s_curve
import torch
import torch.nn as nn
import io
from PIL import Image

2 加載數(shù)據(jù)集

這里選擇S 形曲線數(shù)據(jù)集作為本次復(fù)現(xiàn)擴(kuò)散模型所用數(shù)據(jù)集。

s_curve, _ = make_s_curve(10 ** 4, noise=0.1)

將數(shù)據(jù)集中的特征縮放到一個(gè)相對(duì)較小的范圍內(nèi),以便于模型的訓(xùn)練和收斂。這樣做可以避免數(shù)據(jù)的特征值之間差異過大,導(dǎo)致某些特征對(duì)模型的影響過大,而其他特征的影響被忽略的情況。同時(shí),將數(shù)據(jù)的特征縮放到一個(gè)相對(duì)較小的范圍內(nèi),也有助于提高模型的泛化能力,使其能夠更好地適應(yīng)新的未知數(shù)據(jù)。

s_curve = s_curve[:, [0, 2]] / 10.
print(F"shape of Moons:{np.shape(s_curve)}")

將數(shù)據(jù)集從原來的 (10000, 2) 轉(zhuǎn)換為 (2, 10000),即每一列對(duì)應(yīng)一個(gè)樣本的所有特征值,這樣的形狀更適合一些深度學(xué)習(xí)框架的輸入格式。同時(shí)還可以保持?jǐn)?shù)據(jù)的連續(xù)性:對(duì)數(shù)據(jù)進(jìn)行轉(zhuǎn)置操作可以保持?jǐn)?shù)據(jù)之間的連續(xù)性。在某些機(jī)器學(xué)習(xí)算法或深度學(xué)習(xí)框架中,連續(xù)的數(shù)據(jù)在內(nèi)存中存儲(chǔ)更加緊湊,可以更快地讀取和處理數(shù)據(jù),從而提高模型的訓(xùn)練和預(yù)測(cè)效率。

data = s_curve.T
 
# 繪制 S 形曲線數(shù)據(jù)集
fig, ax = plt.subplots()
ax.scatter(*data, color='red', edgecolor='white')
ax.axis('off')

因?yàn)樵谏疃葘W(xué)習(xí)中,通常使用 PyTorch 等深度學(xué)習(xí)框架來實(shí)現(xiàn)模型的訓(xùn)練和預(yù)測(cè)。而 PyTorch 中的數(shù)據(jù)處理對(duì)象是張量(Tensor),因此我們需要將原始數(shù)據(jù)集轉(zhuǎn)換為張量對(duì)象才能進(jìn)行后續(xù)的深度學(xué)習(xí)模型的訓(xùn)練和預(yù)測(cè)。另外,由于深度學(xué)習(xí)模型通常需要浮點(diǎn)數(shù)類型的數(shù)據(jù)作為輸入,因此我們需要使用 float() 將張量的數(shù)據(jù)類型設(shè)置為浮點(diǎn)型。這樣做可以保證輸入數(shù)據(jù)類型的一致性,避免數(shù)據(jù)類型不匹配導(dǎo)致的錯(cuò)誤。

dataset = torch.Tensor(s_curve).float()

3 確定超參數(shù)的值

首先,指定步數(shù)(num_step),這個(gè)步數(shù)可以根據(jù) beta、分布的均值和標(biāo)準(zhǔn)差來共同確定。num_step 指定了擴(kuò)散模型的最終狀態(tài)的計(jì)算次數(shù),每一次計(jì)算對(duì)應(yīng)一個(gè) beta 值。

接著,使用 torch.linspace() 函數(shù)生成一個(gè)等間隔的 num_step 個(gè) beta 值。然后,通過對(duì)這些 beta 值執(zhí)行 sigmoid 激活函數(shù)以及線性變換,將它們轉(zhuǎn)換為介于 1e-5 到 0.5e-2 之間的浮點(diǎn)數(shù)。這些 beta 值將在后續(xù)計(jì)算中用于計(jì)算擴(kuò)散模型的每一步的參數(shù)。

接下來,計(jì)算一些中間變量,包括 alphas、alphas_prod、alphas_prod_p、alphas_bar_sqrt、one_minus_alphas_bar_log 和 one_minus_alphas_bar_sqrt。其中,alphas 表示每一步的 alpha 值,alphas_prod 表示前 t 步的 alpha 值的累積乘積,alphas_prod_p 表示前 t-1 步的 alpha 值的累積乘積,alphas_bar_sqrt 表示前 t 步的 alpha 值的累積乘積的平方根,one_minus_alphas_bar_log 表示前 t 步的 alpha 值的累積乘積的對(duì)數(shù)的負(fù)值,one_minus_alphas_bar_sqrt 表示前 t 步的 alpha 值的累積乘積的差值的平方根。

最后,使用 assert 命令檢查計(jì)算的所有變量的形狀是否相同,并打印出 betas 變量的形狀。

num_step = 100  # 一開始可以由beta、分布的均值和標(biāo)準(zhǔn)差來共同確定
 
# 指定每一步的beta
betas = torch.linspace(-6, 6, num_step)
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
 
# 計(jì)算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等變量的值
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, dim=0)
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0)  # p表示previous
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_log = torch.log(1 - alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
 
assert alphas.shape == alphas_prod.shape == alphas_prod_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape
print(f"all the same shape:{betas.shape}")

 4 確定擴(kuò)散過程任意時(shí)刻的采樣值

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

首先從正態(tài)分布中生成隨機(jī)噪聲。然后,根據(jù)參數(shù)重整化技巧,使用預(yù)先計(jì)算好的 alpha_bar_sqrt 和 one_minus_alphas_bar_sqrt,將初始值 x_0 進(jìn)行變換,得到時(shí)刻 t 的采樣值。最后,將噪聲加入到采樣值中,得到最終的采樣值。

# 計(jì)算任意時(shí)刻的x的采樣值,基于x_0和參數(shù)重整化技巧
def q_x(x_0, t):
    """可以基于x[0]得到任意時(shí)刻t的x[t]"""
    noise = torch.randn_like(x_0)  # noise是從正態(tài)分布中生成的隨機(jī)噪聲
    alphas_t = alphas_bar_sqrt[t]
    alphas_l_m_t = one_minus_alphas_bar_sqrt[t]
    return (alphas_t * x_0 + alphas_l_m_t * noise) # 在x[0]的基礎(chǔ)上添加噪聲

5 演示原始數(shù)據(jù)分布加噪100步后的效果

生成樣本點(diǎn)隨時(shí)間變化的演化過程圖。生成一個(gè)大小為2x10的子圖網(wǎng)格,每個(gè)子圖顯示了原始S曲線數(shù)據(jù)集在經(jīng)過噪聲添加和擴(kuò)散操作后在某個(gè)時(shí)間點(diǎn)t時(shí)的圖像。其中,num_shows變量指定了要顯示的時(shí)間步數(shù),這里為20,因此總共會(huì)顯示20張子圖。在每個(gè)子圖中,使用q_x函數(shù)對(duì)原始數(shù)據(jù)集進(jìn)行噪聲添加和擴(kuò)散操作,得到對(duì)應(yīng)時(shí)間點(diǎn)t時(shí)的新數(shù)據(jù)集,然后在子圖中以紅色散點(diǎn)圖的形式繪制出來。每個(gè)子圖的標(biāo)題顯示了該子圖所對(duì)應(yīng)的時(shí)間步t。

num_shows = 20
fig, axs = plt.subplots(2, 10, figsize=(28, 7))
plt.rc('text', color='blue')
# 共有10000個(gè)點(diǎn),每個(gè)點(diǎn)包含兩個(gè)坐標(biāo)
# 生成100步以內(nèi)每隔5步加噪聲后的圖像
for i in range(num_shows):
    j = i // 10
    k = i % 10
    q_i = q_x(dataset, torch.tensor([i * num_step // num_shows]))  # 生成t時(shí)刻的采樣數(shù)據(jù)
    axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white')
    
    axs[j, k].set_axis_off()
    axs[j, k].set_title('$q(\mathbf{x}_{'+ str(i * num_step // num_shows) + '})$')

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

6 編寫擬合逆擴(kuò)散過程高斯分布的模型

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

在輸入的基礎(chǔ)上添加一個(gè)時(shí)間步長(zhǎng) t,并對(duì)此進(jìn)行嵌入。具體來說,它使用了 3 個(gè) nn.Embedding 層,分別對(duì)應(yīng)于嵌入 t 的 3 個(gè)維度。

模型的 forward 方法接受一個(gè)輸入 x 和一個(gè)時(shí)間步長(zhǎng) t,并返回輸出 y。在 forward 方法中,輸入 x 會(huì)經(jīng)過一系列的全連接層(使用 nn.Linear 實(shí)現(xiàn)),其中每?jī)蓚€(gè)全連接層之間都有一個(gè) ReLU 激活函數(shù)。在這些全連接層之前和之后,模型都會(huì)使用 nn.Embedding 層將 t 嵌入到向量中。最終的輸出 y 是一個(gè) 2 維向量。

class MLPDiffusion(nn.Module):
    def __init__(self, n_steps, num_units=128):
        super(MLPDiffusion, self).__init__()
        
        self.linears = nn.ModuleList(
        [
            nn.Linear(2, num_units),
            nn.ReLU(),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            nn.Linear(num_units, num_units),
            nn.ReLU(),
            nn.Linear(num_units, 2),
        ])
        
        self.step_embeddings = nn.ModuleList(
        [
            nn.Embedding(n_steps, num_units),
            nn.Embedding(n_steps, num_units),
            nn.Embedding(n_steps, num_units),
        ])
    
    def forward(self, x, t):
        for idx, embedding_layer in enumerate(self.step_embeddings):
            t_embedding = embedding_layer(t)
            x = self.linears[2 * idx](x)
            x += t_embedding
            x = self.linears[2 * idx + 1](x)
        
        x = self.linears[-1](x)
        return x

7 編寫訓(xùn)練的誤差函數(shù)

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps):
    """
    對(duì)任意時(shí)刻t進(jìn)行采樣計(jì)算loss
    param:
    model:模型 
    x_0:初始狀態(tài) 
    alphas_bar_sqrt、one_minus_alphas_bar_sqrt: 參數(shù) 
    n_steps:時(shí)間步數(shù)
    return:損失值
    """
    batch_size = x_0.shape[0]
    # 隨機(jī)采樣一個(gè)時(shí)刻t,為了提高訓(xùn)練效率,這里確保t不重復(fù)
    # 對(duì)一個(gè)batchsize樣本生成隨機(jī)的時(shí)刻t,覆蓋到更多不同的t
    t = torch.randint(0, n_steps, size=(batch_size // 2,))
    t = torch.cat([t, n_steps - 1 - t], dim=0)  # [batch]
    t = t.unsqueeze(-1)  # [batch, 1]
    # x0的系數(shù)
    a = alphas_bar_sqrt[t]
    # eps的系數(shù)
    aml = one_minus_alphas_bar_sqrt[t]
    # 生成隨機(jī)噪聲eps
    e = torch.randn_like(x_0)
    # 構(gòu)造模型的輸入
    x = x_0 * a + e * aml
    # 送入模型,得到t時(shí)刻的隨機(jī)噪聲預(yù)測(cè)值
    output = model(x, t.squeeze(-1))
    # 與真實(shí)噪聲一起計(jì)算誤差,求平均值
    return (e - output).square().mean()

8 編寫逆擴(kuò)散采樣函數(shù)(inference過程)

進(jìn)行擴(kuò)散模型的采樣。具體來說,p_sample_loop函數(shù)是從x[T]恢復(fù)x[T-1]、x[T-2]、...、x[0]的過程,其中x[T]是輸入的初始值。在這個(gè)函數(shù)里,使用了一個(gè)for循環(huán),從最后一個(gè)時(shí)刻T開始往前推,依次對(duì)每個(gè)時(shí)刻進(jìn)行采樣。在每個(gè)時(shí)刻,調(diào)用p_sample函數(shù)進(jìn)行采樣。

p_sample函數(shù)的主要作用是從x[T]采樣t時(shí)刻的重構(gòu)值,其中x[T]是輸入的初始值,t表示當(dāng)前時(shí)刻。具體來說,首先通過模型預(yù)測(cè)出eps_theta,然后通過一些計(jì)算,得到該時(shí)刻的重構(gòu)值sample。其中,mean表示重構(gòu)值的均值,z是服從標(biāo)準(zhǔn)正態(tài)分布的噪聲,sigma_t是該時(shí)刻的標(biāo)準(zhǔn)差。最后,將sample作為當(dāng)前時(shí)刻的重構(gòu)值返回。

def p_sample_loop(model, shape, n_step, betas, one_minus_alphas_bar_sqrt):
    """從x[T]恢復(fù)x[T - 1]、x[T - 2]、...、x[0]"""
    cur_x = torch.randn(shape)
    x_seq = [cur_x]
    for i in reversed(range(n_step)):
        cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt)
        x_seq.append(cur_x)
    return x_seq
 
def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt):
    """從x[T]采樣t時(shí)刻的重構(gòu)值"""
    t = torch.tensor([t])
    coeff = betas[t] / one_minus_alphas_bar_sqrt[t]
    eps_theta = model(x, t)
    mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta))
    z = torch.randn_like(x)
    sigma_t = betas[t].sqrt()
    sample = mean + sigma_t * z
    return (sample)

9 開始訓(xùn)練模型,并打印loss及中間重構(gòu)效果

這段代碼定義了一個(gè)EMA(Exponential Moving Average,指數(shù)平滑移動(dòng)平均)類,它用于對(duì)模型的參數(shù)進(jìn)行平滑處理。構(gòu)造函數(shù)中的 mu 參數(shù)控制平滑程度,shadow 是一個(gè)字典,用于存儲(chǔ)參數(shù)的平滑后的值。

register 方法將參數(shù) val 注冊(cè)到 shadow 字典中,__call__方法對(duì)指定名稱的參數(shù) name 進(jìn)行平滑處理。其中,x 是當(dāng)前時(shí)刻參數(shù)的值。計(jì)算完成后,將結(jié)果存儲(chǔ)在 shadow 字典中,并返回平滑后的值。

seed = 1234  # 確保程序在每次運(yùn)行時(shí)生成的隨機(jī)數(shù)序列都是一樣的
 
class EMA():
    """構(gòu)建一個(gè)參數(shù)平滑器,以便更好地泛化模型并減少過擬合"""
    def __init__(self, mu=0.01):
        self.mu = mu
        self.shadow = {}
        
    def register(self, name, val):
        self.shadow[name] = val.clone()
        
    def __call__(self, name, x):
        assert name in self.shadow
        new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name]
        return new_average
print('Training model.....')
 
 
batch_size = 512  # 批訓(xùn)練大小
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_epoch = 4000  # 定義迭代4000次
plt.rc('text', color='blue')
 
model = MLPDiffusion(num_step)  # 輸出維度是2,輸入是x和step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
for t in range(num_epoch):
    for idx, batch_x in enumerate(dataloader):
        loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_step)
        optimizer.zero_grad()  # 對(duì)梯度進(jìn)行清零,防止網(wǎng)絡(luò)權(quán)重更新過于迅速或不穩(wěn)定,無法得到正確的收斂結(jié)果
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.)  # 對(duì)梯度進(jìn)行裁剪,避免出現(xiàn)梯度爆炸
        optimizer.step()
    if (t % 100 == 0):
        print(loss)
        x_seq = p_sample_loop(model, dataset.shape, num_step, betas, one_minus_alphas_bar_sqrt)  # 共有100個(gè)元素
        
        fig, axs = plt.subplots(1, 5, figsize=(28, 7))
        for i in range(1, 6):
            cur_x = x_seq[i * 20].detach()
            axs[i - 1].scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white')
            axs[i - 1].set_axis_off()
            axs[i - 1].set_title('$q(\mathbf{x}_{'+str(i * 20)+'})$')

部分效果圖:

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

10 動(dòng)畫演示擴(kuò)散過程和逆擴(kuò)散過程

# 生成前向過程,也就是逐步加噪聲
imgs = []
for i in range(100):
    plt.clf()
    torch_i = q_x(dataset, torch.tensor([i]))
    plt.scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white', s=5)
    plt.axis('off')
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf, format='png')
    img = Image.open(img_buf)
    imgs.append(img)

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

# 生成逆過程,也就是逐步復(fù)原
reverse = []
for i in range(100):
    plt.clf()
    cur_x = x_seq[i].detach()  # 拿到訓(xùn)練末尾階段生成的x_seq
    plt.scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white', s=5)
    plt.axis('off')
    
    img_buf = io.BytesIO()
    plt.savefig(img_buf, format='png')
    img = Image.open(img_buf)
    reverse.append(img)

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

imgs = imgs + reverse
 
imgs[0].save("diffusion.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)

動(dòng)畫效果圖:

Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果

關(guān)于“Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果”這篇文章的內(nèi)容就介紹到這里,感謝各位的閱讀!相信大家對(duì)“Pytorch怎么實(shí)現(xiàn)擴(kuò)散模型效果”知識(shí)都有一定的了解,大家如果還想學(xué)習(xí)更多知識(shí),歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI