溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化

發(fā)布時(shí)間:2021-09-30 17:12:22 來源:億速云 閱讀:162 作者:iii 欄目:開發(fā)技術(shù)

這篇文章主要介紹“如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化”,在日常操作中,相信很多人在如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化問題上存在疑惑,小編查閱了各式資料,整理出簡單好用的操作方法,希望對大家解答”如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化”的疑惑有所幫助!接下來,請跟著小編一起來學(xué)習(xí)吧!

如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化

下面進(jìn)行一個(gè)高維線性實(shí)驗(yàn)

假設(shè)我們的真實(shí)方程是:

如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化

假設(shè)feature數(shù)200,訓(xùn)練樣本和測試樣本各20個(gè)

模擬數(shù)據(jù)集

num_train,num_test = 10,10
num_features = 200
true_w = torch.ones((num_features,1),dtype=torch.float32) * 0.01
true_b = torch.tensor(0.5)
samples = torch.normal(0,1,(num_train+num_test,num_features))
noise = torch.normal(0,0.01,(num_train+num_test,1))
labels = samples.matmul(true_w) + true_b + noise
train_samples, train_labels= samples[:num_train],labels[:num_train]
test_samples, test_labels = samples[num_train:],labels[num_train:]

定義帶正則項(xiàng)的loss function

def loss_function(predict,label,w,lambd):
    loss = (predict - label) ** 2
    loss = loss.mean() + lambd * (w**2).mean()
    return loss

畫圖的方法

def semilogy(x_val,y_val,x_label,y_label,x2_val,y2_val,legend):
    plt.figure(figsize=(3,3))
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.semilogy(x_val,y_val)
    if x2_val and y2_val:
        plt.semilogy(x2_val,y2_val)
        plt.legend(legend)
    plt.show()

擬合和畫圖

def fit_and_plot(train_samples,train_labels,test_samples,test_labels,num_epoch,lambd):
    w = torch.normal(0,1,(train_samples.shape[-1],1),requires_grad=True)
    b = torch.tensor(0.,requires_grad=True)
    optimizer = torch.optim.Adam([w,b],lr=0.05)
    train_loss = []
    test_loss = []
    for epoch in range(num_epoch):
        predict = train_samples.matmul(w) + b
        epoch_train_loss = loss_function(predict,train_labels,w,lambd)
        optimizer.zero_grad()
        epoch_train_loss.backward()
        optimizer.step()
        test_predict = test_sapmles.matmul(w) + b
        epoch_test_loss = loss_function(test_predict,test_labels,w,lambd)
        train_loss.append(epoch_train_loss.item())
        test_loss.append(epoch_test_loss.item())
    semilogy(range(1,num_epoch+1),train_loss,'epoch','loss',range(1,num_epoch+1),test_loss,['train','test'])

如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化
可以發(fā)現(xiàn)加了正則項(xiàng)的模型,在測試集上的loss確實(shí)下降了

到此,關(guān)于“如何理解Python中的pyTorch權(quán)重衰減與L2范數(shù)正則化”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實(shí)踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識,請繼續(xù)關(guān)注億速云網(wǎng)站,小編會(huì)繼續(xù)努力為大家?guī)砀鄬?shí)用的文章!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI