溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Pytorch中的torch.distributions庫(kù)怎么使用

發(fā)布時(shí)間:2023-02-24 11:41:17 來(lái)源:億速云 閱讀:92 作者:iii 欄目:開發(fā)技術(shù)

本文小編為大家詳細(xì)介紹“Pytorch中的torch.distributions庫(kù)怎么使用”,內(nèi)容詳細(xì),步驟清晰,細(xì)節(jié)處理妥當(dāng),希望這篇“Pytorch中的torch.distributions庫(kù)怎么使用”文章能幫助大家解決疑惑,下面跟著小編的思路慢慢深入,一起來(lái)學(xué)習(xí)新知識(shí)吧。

Pytorch torch.distributions庫(kù)

包介紹

torch.distributions包包含可參數(shù)化的概率分布和采樣函數(shù)。 這允許構(gòu)建用于優(yōu)化的隨機(jī)計(jì)算圖和隨機(jī)梯度估計(jì)器。

不可能通過(guò)隨機(jī)樣本直接反向傳播。 但是,有兩種主要方法可以創(chuàng)建可以反向傳播的代理函數(shù)。

這些是

評(píng)分函數(shù)估計(jì)量 score function estimato
似然比估計(jì)量 likelihood ratio estimator
REINFORCE
路徑導(dǎo)數(shù)估計(jì)量 pathwise derivative estimator
REINFORCE 通常被視為強(qiáng)化學(xué)習(xí)中策略梯度方法的基礎(chǔ),

路徑導(dǎo)數(shù)估計(jì)器常見于變分自編碼器的重新參數(shù)化技巧中。

雖然評(píng)分函數(shù)只需要樣本 f(x)的值,但路徑導(dǎo)數(shù)需要導(dǎo)數(shù) f'(x)。

本文重點(diǎn)講解Pytorch中的 torch.distributions庫(kù)。

pytorch 的 torch.distributions 中可以定義正態(tài)分布:

import torch
from torch.distributions import  Normal
mean=torch.Tensor([0,2])
normal=Normal(mean,1)

sample()就是直接在定義的正太分布(均值為mean,標(biāo)準(zhǔn)差std是1)上采樣:

result = normal.sample()
print("sample():",result)

輸出:

sample(): tensor([-1.3362,  3.1730])

rsample()不是在定義的正太分布上采樣,而是先對(duì)標(biāo)準(zhǔn)正太分布 N(0,1) 進(jìn)行采樣,然后輸出: mean + std × 采樣值

result = normal.rsample()
print("rsample():",result)

輸出:

rsample: tensor([ 0.0530,  2.8396])

log_prob(value) 是計(jì)算value在定義的正態(tài)分布(mean,1)中對(duì)應(yīng)的概率的對(duì)數(shù),正太分布概率密度函數(shù)是:

Pytorch中的torch.distributions庫(kù)怎么使用

對(duì)其取對(duì)數(shù)可得:

Pytorch中的torch.distributions庫(kù)怎么使用

這里我們通過(guò)對(duì)數(shù)概率還原其對(duì)應(yīng)的真實(shí)概率:

print("result log_prob:",normal.log_prob(result).exp())

輸出:

result log_prob: tensor([ 0.1634,  0.2005])

讀到這里,這篇“Pytorch中的torch.distributions庫(kù)怎么使用”文章已經(jīng)介紹完畢,想要掌握這篇文章的知識(shí)點(diǎn)還需要大家自己動(dòng)手實(shí)踐使用過(guò)才能領(lǐng)會(huì),如果想了解更多相關(guān)內(nèi)容的文章,歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI