溫馨提示×

PyTorch中怎么實現池化層

小億
100
2024-05-10 19:10:54
欄目: 深度學習

在PyTorch中,可以使用torch.nn.MaxPool2d來實現池化層。torch.nn.MaxPool2d會對輸入數據進行最大池化操作,即在每個池化窗口內取最大值作為輸出。

以下是一個簡單的例子,演示如何在PyTorch中使用torch.nn.MaxPool2d實現池化層:

import torch
import torch.nn as nn

# 創(chuàng)建一個輸入張量
input_data = torch.rand(1, 1, 5, 5)  # 1個樣本,1個通道,5x5的圖像

# 創(chuàng)建一個最大池化層
max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

# 對輸入數據進行最大池化
output = max_pool(input_data)

print("輸入數據:")
print(input_data)

print("\n最大池化后的輸出:")
print(output)

在這個例子中,我們首先創(chuàng)建了一個5x5的輸入張量,并創(chuàng)建了一個池化窗口大小為2x2的最大池化層。然后我們對輸入數據進行最大池化操作,并輸出池化后的結果。

運行上面的代碼,你將看到輸入數據和經過最大池化后的輸出結果。

0