python的nn.linear怎么使用

小億
135
2023-12-22 10:10:59
欄目: 編程語言

在PyTorch中,nn.Linear是一個(gè)用于定義線性變換的類??梢允褂盟鼇矶x一個(gè)全連接層(也稱為線性層)。

以下是如何使用nn.Linear的例子:

首先,導(dǎo)入需要的模塊:

import torch
import torch.nn as nn

接下來,定義一個(gè)包含輸入和輸出大小的線性層:

input_size = 10
output_size = 5

linear_layer = nn.Linear(input_size, output_size)

這將創(chuàng)建一個(gè)線性層,將輸入維度為10的特征映射到輸出維度為5的特征。

然后,可以將數(shù)據(jù)傳遞給線性層進(jìn)行轉(zhuǎn)換:

input_data = torch.randn(1, input_size)
output_data = linear_layer(input_data)

這將根據(jù)線性層的權(quán)重和偏差將輸入數(shù)據(jù)進(jìn)行線性變換,并返回輸出數(shù)據(jù)。

最后,可以查看線性層的權(quán)重和偏差:

print(linear_layer.weight)
print(linear_layer.bias)

這將打印出線性層的權(quán)重矩陣和偏差向量。

注意:nn.Linear類還可以接受一些其他參數(shù),例如是否添加偏差(默認(rèn)為True)、權(quán)重初始化方法等。你可以查閱PyTorch的官方文檔以獲取更多詳細(xì)信息。

0