溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點(diǎn)擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

pytorch中Parameter函數(shù)怎么使用

發(fā)布時(shí)間:2022-02-07 15:20:42 來源:億速云 閱讀:286 作者:iii 欄目:開發(fā)技術(shù)

這篇文章主要介紹了pytorch中Parameter函數(shù)怎么使用的相關(guān)知識,內(nèi)容詳細(xì)易懂,操作簡單快捷,具有一定借鑒價(jià)值,相信大家閱讀完這篇pytorch中Parameter函數(shù)怎么使用文章都會有所收獲,下面我們一起來看看吧。

用法介紹

pytorch中的Parameter函數(shù)可以對某個(gè)張量進(jìn)行參數(shù)化。它可以將不可訓(xùn)練的張量轉(zhuǎn)化為可訓(xùn)練的參數(shù)類型,同時(shí)將轉(zhuǎn)化后的張量綁定到模型可訓(xùn)練參數(shù)的列表中,當(dāng)更新模型的參數(shù)時(shí)一并將其更新。

torch.nn.parameter.Parameter

  • data (Tensor):表示需要參數(shù)化的張量

  • requires_grad (bool, optional):表示是否該張量是否需要梯度,默認(rèn)值為True

代碼介紹

 pytorch中的Parameter函數(shù)具體的代碼示例如下所示

import torch
import torch.nn as nn
class NeuralNetwork(nn.Module):
	def __init__(self, input_dim, output_dim):
		super(NeuralNetwork, self).__init__()
		self.linear = nn.Linear(input_dim, output_dim)
		self.linear.weight = torch.nn.Parameter(torch.zeros(input_dim, output_dim))
		self.linear.bias = torch.nn.Parameter(torch.ones(output_dim))
	def forward(self, input_array):
		output = self.linear(input_array)
		return output
if __name__ == '__main__':
	net = NeuralNetwork(4, 6)
	for param in net.parameters():
		print(param)

代碼的結(jié)果如下所示:

pytorch中Parameter函數(shù)怎么使用

當(dāng)神經(jīng)網(wǎng)絡(luò)的參數(shù)不是用Parameter函數(shù)參數(shù)化直接賦值給權(quán)重參數(shù)時(shí),則會報(bào)錯(cuò),具體的程序

import torch
import torch.nn as nn
class NeuralNetwork(nn.Module):
	def __init__(self, input_dim, output_dim):
		super(NeuralNetwork, self).__init__()
		self.linear = nn.Linear(input_dim, output_dim)
		self.linear.weight = torch.zeros(input_dim, output_dim)
		self.linear.bias = torch.ones(output_dim)
	def forward(self, input_array):
		output = self.linear(input_array)
		return output
if __name__ == '__main__':
	net = NeuralNetwork(4, 6)
	for param in net.parameters():
		print(param)

代碼運(yùn)行報(bào)錯(cuò)結(jié)果如下所示:

pytorch中Parameter函數(shù)怎么使用

關(guān)于“pytorch中Parameter函數(shù)怎么使用”這篇文章的內(nèi)容就介紹到這里,感謝各位的閱讀!相信大家對“pytorch中Parameter函數(shù)怎么使用”知識都有一定的了解,大家如果還想學(xué)習(xí)更多知識,歡迎關(guān)注億速云行業(yè)資訊頻道。

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI