您好,登錄后才能下訂單哦!
本篇內(nèi)容介紹了“Python回歸樹如何實(shí)現(xiàn)”的有關(guān)知識,在實(shí)際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領(lǐng)大家學(xué)習(xí)一下如何處理這些情況吧!希望大家仔細(xì)閱讀,能夠?qū)W有所成!
首先導(dǎo)入庫
import pandas as pd import numpy as np import matplotlib.pyplot as plt
首先需要?jiǎng)?chuàng)建訓(xùn)練數(shù)據(jù),我們的數(shù)據(jù)將具有獨(dú)立變量(x)和一個(gè)相關(guān)的變量(y),并使用numpy在相關(guān)值中添加高斯噪聲,可以用數(shù)學(xué)表達(dá)為
這里的???? 是噪聲。代碼如下所示。
def f(x): mu, sigma = 0, 1.5 return -x**2 + x + 5 + np.random.normal(mu, sigma, 1) num_points = 300 np.random.seed(1) x = np.random.uniform(-2, 5, num_points) y = np.array( [f(i) for i in x] ) plt.scatter(x, y, s = 5)
在回歸樹中是通過創(chuàng)建一個(gè)多個(gè)節(jié)點(diǎn)的樹來預(yù)測數(shù)值數(shù)據(jù)的。下圖展示了一個(gè)回歸樹的樹結(jié)構(gòu)示例,其中每個(gè)節(jié)點(diǎn)都有其用于劃分?jǐn)?shù)據(jù)的閾值。
給定一組數(shù)據(jù),輸入值將通過相應(yīng)的規(guī)格達(dá)到葉子節(jié)點(diǎn)。達(dá)到節(jié)點(diǎn)M的所有輸入值可以用X的子集表示。從數(shù)學(xué)上講,讓我們用一個(gè)函數(shù)表達(dá)此情況,如果給定的輸入值達(dá)到節(jié)點(diǎn)M,則可以給出1個(gè),否則為0。
找到分裂數(shù)據(jù)的閾值:通過在每個(gè)步驟中選擇2個(gè)連續(xù)點(diǎn)并計(jì)算其平均值來迭代訓(xùn)練數(shù)據(jù)。計(jì)算的平均值將數(shù)據(jù)分為兩個(gè)的閾值。
首先讓我們考慮隨機(jī)閾值以演示任何給定的情況。
threshold = 1.5 low = np.take(y, np.where(x < threshold)) high = np.take(y, np.where(x > threshold)) plt.scatter(x, y, s = 5, label = 'Data') plt.plot([threshold]*2, [-16, 10], 'b--', label = 'Threshold line') plt.plot([-2, threshold], [low.mean()]*2, 'r--', label = 'Left child prediction line') plt.plot([threshold, 5], [high.mean()]*2, 'r--', label = 'Right child prediction line') plt.plot([-2, 5], [y.mean()]*2, 'g--', label = 'Node prediction line') plt.legend()
藍(lán)色垂直線表示單個(gè)閾值,我們假設(shè)它是任意兩點(diǎn)的均值,并稍后將其用于劃分?jǐn)?shù)據(jù)。
我們對這個(gè)問題的第一個(gè)預(yù)測是所有訓(xùn)練數(shù)據(jù)(y軸)的平均值(綠色水平線)。而兩條紅線是要?jiǎng)?chuàng)建的子節(jié)點(diǎn)的預(yù)測。
很明顯這些平均值都不能很好地代表我們的數(shù)據(jù),但它們的差異也是很明顯的:主節(jié)點(diǎn)預(yù)測(綠線)得到所有訓(xùn)練數(shù)據(jù)的均值,我們將其分為2個(gè)子節(jié)點(diǎn),這2個(gè)子節(jié)點(diǎn)有自己的預(yù)測(紅線)。與綠線相比這2個(gè)子節(jié)點(diǎn)更好地代表了它們對應(yīng)的訓(xùn)練數(shù)據(jù)。回歸樹就是將不斷地將數(shù)據(jù)分成2個(gè)部分——從每個(gè)節(jié)點(diǎn)創(chuàng)建2個(gè)子節(jié)點(diǎn),直到達(dá)到給定的停止值(這是一個(gè)節(jié)點(diǎn)所能擁有的最小數(shù)據(jù)量)。它會提前停止樹的構(gòu)建過程,我們將其稱為預(yù)修剪樹。
為什么會有早停的機(jī)制?如果我們要繼續(xù)進(jìn)行分配直到節(jié)點(diǎn)只有一個(gè)值是,這創(chuàng)建一個(gè)過度擬合的方案,每個(gè)訓(xùn)練數(shù)據(jù)都只能預(yù)測自己。
說明:當(dāng)模型完成時(shí),它不會使用根節(jié)點(diǎn)或任何中間節(jié)點(diǎn)來預(yù)測任何值;它將使用回歸樹的葉子(這將是樹的最后一個(gè)節(jié)點(diǎn))進(jìn)行預(yù)測。
為了得到最能代表給定閾值數(shù)據(jù)的閾值,我們使用殘差平方和。它可以在數(shù)學(xué)上定義為
讓我們看看這一步是如何工作的。
既然計(jì)算了閾值的SSR值,那么可以采用具有最小SSR值的閾值。使用該閾值將訓(xùn)練數(shù)據(jù)分為兩個(gè)(低和高部分),其中其中低部分將用于創(chuàng)建左子節(jié)點(diǎn),高部分將用于創(chuàng)建右子節(jié)點(diǎn)。
def SSR(r, y): return np.sum( (r - y)**2 ) SSRs, thresholds = [], [] for i in range(len(x) - 1): threshold = x[i:i+2].mean() low = np.take(y, np.where(x < threshold)) high = np.take(y, np.where(x > threshold)) guess_low = low.mean() guess_high = high.mean() SSRs.append(SSR(low, guess_low) + SSR(high, guess_high)) thresholds.append(threshold) print('Minimum residual is: {:.2f}'.format(min(SSRs))) print('Corresponding threshold value is: {:.4f}'.format(thresholds[SSRs.index(min(SSRs))]))
在進(jìn)入下一步之前,我將使用pandas創(chuàng)建一個(gè)df,并創(chuàng)建一個(gè)用于尋找最佳閾值的方法。所有這些步驟都可以在沒有pandas的情況下完成,這里使用他是因?yàn)楸容^方便。
df = pd.DataFrame(zip(x, y.squeeze()), columns = ['x', 'y']) def find_threshold(df, plot = False): SSRs, thresholds = [], [] for i in range(len(df) - 1): threshold = df.x[i:i+2].mean() low = df[(df.x <= threshold)] high = df[(df.x > threshold)] guess_low = low.y.mean() guess_high = high.y.mean() SSRs.append(SSR(low.y.to_numpy(), guess_low) + SSR(high.y.to_numpy(), guess_high)) thresholds.append(threshold) if plot: plt.scatter(thresholds, SSRs, s = 3) plt.show() return thresholds[SSRs.index(min(SSRs))]
在將數(shù)據(jù)分成兩個(gè)部分后就可以為低值和高值找到單獨(dú)的閾值。需要注意的是這里要增加一個(gè)停止條件;因?yàn)閷τ诿總€(gè)節(jié)點(diǎn),屬于該節(jié)點(diǎn)的數(shù)據(jù)集中的點(diǎn)會變少,所以我們?yōu)槊總€(gè)節(jié)點(diǎn)定義了最小數(shù)據(jù)點(diǎn)數(shù)量。如果不這樣做,每個(gè)節(jié)點(diǎn)將只使用一個(gè)訓(xùn)練值進(jìn)行預(yù)測,會導(dǎo)致過擬合。
可以遞歸地創(chuàng)建節(jié)點(diǎn),我們定義了一個(gè)名為TreeNode的類,它將存儲節(jié)點(diǎn)應(yīng)該存儲的每一個(gè)值。使用這個(gè)類我們首先創(chuàng)建根,同時(shí)計(jì)算它的閾值和預(yù)測值。然后遞歸地創(chuàng)建它的子節(jié)點(diǎn),其中每個(gè)子節(jié)點(diǎn)類都存儲在父類的left或right屬性中。
在下面的create_nodes方法中,首先將給定的df分成兩部分。然后檢查是否有足夠的數(shù)據(jù)單獨(dú)創(chuàng)建左右節(jié)點(diǎn)。如果(對于其中任何一個(gè))有足夠的數(shù)據(jù)點(diǎn),我們計(jì)算閾值并使用它創(chuàng)建一個(gè)子節(jié)點(diǎn),用這個(gè)新節(jié)點(diǎn)作為樹再次調(diào)用create_nodes方法。
class TreeNode(): def __init__(self, threshold, pred): self.threshold = threshold self.pred = pred self.left = None self.right = None def create_nodes(tree, df, stop): low = df[df.x <= tree.threshold] high = df[df.x > tree.threshold] if len(low) > stop: threshold = find_threshold(low) tree.left = TreeNode(threshold, low.y.mean()) create_nodes(tree.left, low, stop) if len(high) > stop: threshold = find_threshold(high) tree.right = TreeNode(threshold, high.y.mean()) create_nodes(tree.right, high, stop) threshold = find_threshold(df) tree = TreeNode(threshold, df.y.mean()) create_nodes(tree, df, 5)
這個(gè)方法在第一棵樹上進(jìn)行了修改,因?yàn)樗恍枰祷厝魏螙|西。雖然遞歸函數(shù)通常不是這樣寫的(不返回),但因?yàn)椴恍枰祷刂?,所以?dāng)沒有激活if語句時(shí),不做任何操作。
在完成后可以檢查此樹結(jié)構(gòu),查看它是否創(chuàng)建了一些可以擬合數(shù)據(jù)的節(jié)點(diǎn)。這里將手動(dòng)選擇第一個(gè)節(jié)點(diǎn)及其對根閾值的預(yù)測。
plt.scatter(x, y, s = 0.5, label = 'Data') plt.plot([tree.threshold]*2, [-16, 10], 'r--', label = 'Root threshold') plt.plot([tree.right.threshold]*2, [-16, 10], 'g--', label = 'Right node threshold') plt.plot([tree.threshold, tree.right.threshold], [tree.right.left.pred]*2, 'g', label = 'Right node prediction') plt.plot([tree.left.threshold]*2, [-16, 10], 'm--', label = 'Left node threshold') plt.plot([tree.left.threshold, tree.threshold], [tree.left.right.pred]*2, 'm', label = 'Left node prediction') plt.plot([tree.left.left.threshold]*2, [-16, 10], 'k--', label = 'Second Left node threshold') plt.legend()
這里看到了兩個(gè)預(yù)測:
第一個(gè)左節(jié)點(diǎn)對高值的預(yù)測(高于其閾值)
第一個(gè)右節(jié)點(diǎn)對低值(低于其閾值)的預(yù)測
這里我手動(dòng)剪切了預(yù)測線的寬度,因?yàn)槿绻o定的x值達(dá)到了這些節(jié)點(diǎn)中的任何一個(gè),則將以屬于該節(jié)點(diǎn)的所有x值的平均值表示,這也意味著沒有其他x值參與 在該節(jié)點(diǎn)的預(yù)測中(希望有意義)。
這種樹形結(jié)構(gòu)遠(yuǎn)不止兩個(gè)節(jié)點(diǎn)那么簡單,所以我們可以通過如下調(diào)用它的子節(jié)點(diǎn)來檢查一個(gè)特定的葉子節(jié)點(diǎn)。
tree.left.right.left.left
這當(dāng)然意味著這里有一個(gè)向下4個(gè)子結(jié)點(diǎn)長的分支,但它可以在樹的另一個(gè)分支上深入得多。
我們可以創(chuàng)建一個(gè)預(yù)測方法來預(yù)測任何給定的值。
def predict(x): curr_node = tree result = None while True: if x <= curr_node.threshold: if curr_node.left: curr_node = curr_node.left else: break elif x > curr_node.threshold: if curr_node.right: curr_node = curr_node.right else: break return curr_node.pred
預(yù)測方法做的是沿著樹向下,通過比較我們的輸入和每個(gè)葉子的閾值。如果輸入值大于閾值,則轉(zhuǎn)到右葉,如果小于閾值,則轉(zhuǎn)到左葉,以此類推,直到到達(dá)任何底部葉子節(jié)點(diǎn)。然后使用該節(jié)點(diǎn)自身的預(yù)測值進(jìn)行預(yù)測,并與其閾值進(jìn)行最后的比較。
使用x = 3進(jìn)行測試(在創(chuàng)建數(shù)據(jù)時(shí),可以使用上面所寫的函數(shù)計(jì)算實(shí)際值。-3**2+3+5 = -1,這是期望值),我們得到:
predict(3) # -1.23741
這里用相對平方誤差驗(yàn)證數(shù)據(jù)
def RSE(y, g): return sum(np.square(y - g)) / sum(np.square(y - 1 / len(y)*sum(y))) x_val = np.random.uniform(-2, 5, 50) y_val = np.array( [f(i) for i in x_val] ).squeeze() tr_preds = np.array( [predict(i) for i in df.x] ) val_preds = np.array( [predict(i) for i in x_val] ) print('Training error: {:.4f}'.format(RSE(df.y, tr_preds))) print('Validation error: {:.4f}'.format(RSE(y_val, val_preds)))
可以看到誤差并不大,結(jié)果如下
一個(gè)更適合回歸樹模型的數(shù)據(jù):因?yàn)槲覀兊臄?shù)據(jù)是多項(xiàng)式生成的數(shù)據(jù),所以使用多項(xiàng)式回歸模型可以更好地?cái)M合。我們更換一下訓(xùn)練數(shù)據(jù),把新函數(shù)設(shè)為
def f(x): mu, sigma = 0, 0.5 if x < 3: return 1 + np.random.normal(mu, sigma, 1) elif x >= 3 and x < 6: return 9 + np.random.normal(mu, sigma, 1) elif x >= 6: return 5 + np.random.normal(mu, sigma, 1) np.random.seed(1) x = np.random.uniform(0, 10, num_points) y = np.array( [f(i) for i in x] ) plt.scatter(x, y, s = 5)
在此數(shù)據(jù)集上運(yùn)行了上面的所有相同過程,結(jié)果如下
比我們從多項(xiàng)式數(shù)據(jù)中獲得的誤差低。
最后共享一下上面動(dòng)圖的代碼:
import pandas as pd import numpy as np import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation #===================================================Create Data def f(x): mu, sigma = 0, 1.5 return -x**2 + x + 5 + np.random.normal(mu, sigma, 1) np.random.seed(1) x = np.random.uniform(-2, 5, 300) y = np.array( [f(i) for i in x] ) p = x.argsort() x = x[p] y = y[p] #===================================================Calculate Thresholds def SSR(r, y): #send numpy array return np.sum( (r - y)**2 ) SSRs, thresholds = [], [] for i in range(len(x) - 1): threshold = x[i:i+2].mean() low = np.take(y, np.where(x < threshold)) high = np.take(y, np.where(x > threshold)) guess_low = low.mean() guess_high = high.mean() SSRs.append(SSR(low, guess_low) + SSR(high, guess_high)) thresholds.append(threshold) #===================================================Animated Plot fig, (ax1, ax2) = plt.subplots(2,1, sharex = True) x_data, y_data = [], [] x_data2, y_data2 = [], [] ln, = ax1.plot([], [], 'r--') ln2, = ax2.plot(thresholds, SSRs, 'ro', markersize = 2) line = [ln, ln2] def init(): ax1.scatter(x, y, s = 3) ax1.title.set_text('Trying Different Thresholds') ax2.title.set_text('Threshold vs SSR') ax1.set_ylabel('y values') ax2.set_xlabel('Threshold') ax2.set_ylabel('SSR') return line def update(frame): x_data = [x[frame:frame+2].mean()] * 2 y_data = [min(y), max(y)] line[0].set_data(x_data, y_data) x_data2.append(thresholds[frame]) y_data2.append(SSRs[frame]) line[1].set_data(x_data2, y_data2) return line ani = FuncAnimation(fig, update, frames = 298, init_func = init, blit = True) plt.show()
“Python回歸樹如何實(shí)現(xiàn)”的內(nèi)容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關(guān)的知識可以關(guān)注億速云網(wǎng)站,小編將為大家輸出更多高質(zhì)量的實(shí)用文章!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。