溫馨提示×

Scikit-learn中怎么繪制驗(yàn)證曲線

小億
85
2024-05-10 18:19:55
欄目: 編程語言

在Scikit-learn中,可以使用validation_curve函數(shù)來繪制驗(yàn)證曲線。該函數(shù)可以用于評估模型在不同超參數(shù)取值下的訓(xùn)練集和驗(yàn)證集上的表現(xiàn)。

以下是一個簡單的示例,展示如何繪制一個決策樹模型的最大深度對應(yīng)的驗(yàn)證曲線:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import validation_curve
from sklearn.tree import DecisionTreeClassifier

# 創(chuàng)建決策樹模型
model = DecisionTreeClassifier()

# 定義不同的最大深度取值范圍
param_range = np.arange(1, 10)

# 使用validation_curve函數(shù)計(jì)算訓(xùn)練集和驗(yàn)證集的得分
train_scores, test_scores = validation_curve(
    model, X, y, param_name="max_depth", param_range=param_range,
    scoring="accuracy", n_jobs=1)

# 計(jì)算平均得分
train_scores_mean = np.mean(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)

# 繪制驗(yàn)證曲線
plt.figure()
plt.plot(param_range, train_scores_mean, label="Training score", color="r")
plt.plot(param_range, test_scores_mean, label="Cross-validation score", color="b")
plt.xlabel("max_depth")
plt.ylabel("Score")
plt.title("Validation Curve")
plt.legend(loc="best")
plt.show()

在這個示例中,我們使用DecisionTreeClassifier模型,針對最大深度參數(shù)進(jìn)行了驗(yàn)證曲線的繪制。您可以根據(jù)需要替換模型和超參數(shù),來繪制其他模型的驗(yàn)證曲線。

0