溫馨提示×

PyTorch中怎么評估模型性能

小億
125
2024-05-10 15:40:58
欄目: 深度學習

在PyTorch中,可以使用torch.nn.functional模塊中的函數(shù)來評估模型性能。常用的評估方法包括計算準確率、精確度、召回率、F1分數(shù)等。

下面是一些常用的評估方法示例:

  1. 計算準確率:
def accuracy(output, target):
    pred = output.argmax(dim=1, keepdim=True)
    correct = pred.eq(target.view_as(pred)).sum()
    acc = correct.float() / target.size(0)
    return acc
  1. 計算精確度、召回率和F1分數(shù):
from sklearn.metrics import precision_score, recall_score, f1_score

def precision(output, target):
    pred = output.argmax(dim=1, keepdim=True)
    return precision_score(target, pred)

def recall(output, target):
    pred = output.argmax(dim=1, keepdim=True)
    return recall_score(target, pred)

def f1(output, target):
    pred = output.argmax(dim=1, keepdim=True)
    return f1_score(target, pred)

使用這些函數(shù)可以評估模型在測試集上的性能,例如:

model.eval()
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        acc = accuracy(output, target)
        prec = precision(output, target)
        rec = recall(output, target)
        f1 = f1(output, target)
        
        print(f'Accuracy: {acc}, Precision: {prec}, Recall: {rec}, F1 Score: {f1}')

除了以上示例外,還可以根據(jù)具體問題和需求來選擇不同的評估方法。PyTorch提供了靈活的接口,方便用戶根據(jù)需要進行模型性能評估。

0