如何在PyTorch中創(chuàng)建一個(gè)神經(jīng)網(wǎng)絡(luò)模型

小樊
91
2024-03-05 18:11:04

在PyTorch中創(chuàng)建神經(jīng)網(wǎng)絡(luò)模型通常需要定義一個(gè)繼承自torch.nn.Module類的自定義類。下面是一個(gè)簡(jiǎn)單的示例:

import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)  # 定義一個(gè)全連接層
        self.relu = nn.ReLU()  # 定義一個(gè)激活函數(shù)
        self.fc2 = nn.Linear(128, 10)  # 定義另一個(gè)全連接層

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

在上面的示例中,我們定義了一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)模型SimpleNN,包括兩個(gè)全連接層和一個(gè)ReLU激活函數(shù)。__init__方法用于定義模型的結(jié)構(gòu),forward方法用于定義模型的前向傳播過(guò)程。

要使用這個(gè)模型,可以實(shí)例化一個(gè)對(duì)象并傳入輸入數(shù)據(jù)進(jìn)行前向傳播計(jì)算:

model = SimpleNN()
input_data = torch.randn(1, 784)  # 創(chuàng)建一個(gè)輸入數(shù)據(jù)張量
output = model(input_data)  # 進(jìn)行前向傳播
print(output)

這樣就可以在PyTorch中創(chuàng)建一個(gè)簡(jiǎn)單的神經(jīng)網(wǎng)絡(luò)模型了。您可以根據(jù)自己的需求定義更復(fù)雜的模型結(jié)構(gòu)和前向傳播過(guò)程。

0