PyTorch中怎么初始化模型權(quán)重

小億
141
2024-05-10 15:34:54

在PyTorch中,可以通過(guò)定義一個(gè)函數(shù)來(lái)初始化模型的權(quán)重。以下是一個(gè)示例代碼:

import torch
import torch.nn as nn

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

# 定義模型
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),
    nn.ReLU(),
    nn.Linear(64*28*28, 10)
)

# 初始化模型權(quán)重
model.apply(init_weights)

在上面的代碼中,定義了一個(gè)init_weights函數(shù),該函數(shù)根據(jù)模型的類(lèi)型對(duì)權(quán)重進(jìn)行初始化。然后通過(guò)調(diào)用model.apply(init_weights)來(lái)初始化模型的權(quán)重。這樣就可以保證模型的權(quán)重被正確地初始化。

0