How to implement batch normalization in PyTorch

小億
2024-05-10 19:07:56

In PyTorch, batch normalization is implemented with torch.nn.BatchNorm1d and torch.nn.BatchNorm2d. A concrete code example:

import torch
import torch.nn as nn

# Input data with shape (batch_size, channels, height, width)
input_data = torch.randn(20, 16, 50, 50)

# Batch-normalize 2D (image-like) data
batchnorm = nn.BatchNorm2d(16)  # normalizes over the channel dimension
output_data = batchnorm(input_data)

# Batch-normalize 1D data with shape (batch_size, channels, length)
input_data = torch.randn(20, 16, 100)
batchnorm = nn.BatchNorm1d(16)  # normalizes over the channel/feature dimension
output_data = batchnorm(input_data)

In the code above, nn.BatchNorm2d performs batch normalization on 2D data (such as images) and nn.BatchNorm1d on 1D data. Note that both modules automatically compute and update running estimates of the per-channel mean and variance during training (these running statistics are then used at inference time), and both also learn per-channel gamma (scale) and beta (shift) parameters.
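As a minimal sketch of the point above, the snippet below inspects the learned gamma/beta parameters (exposed as weight and bias) and the running statistics of a BatchNorm2d module, and switches between training mode (batch statistics) and eval mode (running statistics):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)

# Learnable affine parameters: gamma (weight) and beta (bias), one per channel
assert bn.weight.shape == (16,) and bn.bias.shape == (16,)

x = torch.randn(20, 16, 50, 50)

bn.train()   # training mode: normalize with the current batch's statistics
_ = bn(x)    # the forward pass also updates running_mean / running_var

bn.eval()    # eval mode: normalize with the stored running statistics
y = bn(x)
print(y.shape)  # torch.Size([20, 16, 50, 50])
```

Forgetting to call model.eval() before inference is a common source of bugs, since batch statistics from a single test sample can differ sharply from the accumulated running statistics.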
