溫馨提示×

PyTorch PyG能支持自定義層嗎

小樊
81
2024-10-22 06:08:58

PyTorch的PyG庫可以支持自定義層。在PyTorch中,可以通過繼承torch.nn.Module類來創(chuàng)建自定義層。例如,定義一個簡單的全連接層,可以這樣做:

import torch
import torch.nn as nn

class MyLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MyLayer, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

在這個例子中,MyLayer類繼承自nn.Module,并定義了一個全連接層self.linear。在forward方法中,我們將輸入x傳遞給這個全連接層,并返回其輸出。

然后,在使用PyG庫時,可以將這個自定義層添加到圖結(jié)構(gòu)中。例如,定義一個包含自定義層和PyTorch nn.Linear層的圖結(jié)構(gòu):

from torch_geometric.nn import MessagePassing
import torch

class MyModel(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MyModel, self).__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)
        self.my_layer = MyLayer(in_channels, 64)

    def forward(self, x, edge_index):
        row, col = edge_index
        x = self.my_layer(x)
        x = self.lin(x)
        row, col = row.view(-1, 1), col.view(-1, 1)
        deg = self.degree(row, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def degree(self, row, num_nodes, dtype):
        row, col = row.to(dtype), col.to(dtype)
        deg = torch.bincount(row, minlength=num_nodes, dtype=dtype)
        deg = deg[row] + deg[col]
        return deg.view(-1, 1)

在這個例子中,MyModel類繼承自MessagePassing,并定義了一個包含自定義層self.my_layer和PyTorch nn.Linear層的圖結(jié)構(gòu)。在forward方法中,我們首先對輸入x應(yīng)用自定義層,然后應(yīng)用線性層,最后根據(jù)邊的權(quán)重計算消息和更新節(jié)點特征。

0