溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

Java線性回歸基礎(chǔ)代碼怎么寫

發(fā)布時(shí)間:2021-12-21 10:44:24 來(lái)源:億速云 閱讀:140 作者:iii 欄目:編程語(yǔ)言

這篇文章主要介紹“Java線性回歸基礎(chǔ)代碼怎么寫”,在日常操作中,相信很多人在Java線性回歸基礎(chǔ)代碼怎么寫問(wèn)題上存在疑惑,小編查閱了各式資料,整理出簡(jiǎn)單好用的操作方法,希望對(duì)大家解答”Java線性回歸基礎(chǔ)代碼怎么寫”的疑惑有所幫助!接下來(lái),請(qǐng)跟著小編一起來(lái)學(xué)習(xí)吧!

# Use linear model to model this data.
from sklearn.linear_model import LinearRegression
import numpy as np
lr=LinearRegression()
lr.fit(pga.distance[:,np.newaxis],pga['accuracy']) # Another way is using pga[['distance']]
theta0=lr.intercept_
theta1=lr.coef_
print(theta0)
print(theta1)
#calculating cost-function for each theta1
#計(jì)算平均累積誤差
def cost(x,y,theta0,theta1):
 J=0
 for i in range(len(x)):
 mse=(x[i]*theta1+theta0-y[i])**2
 J+=mse
 return J/(2*len(x))
theta0=100
theta1s = np.linspace(-3,2,197)
costs=[]
for theta1 in theta1s:
 costs.append(cost(pga['distance'],pga['accuracy'],theta0,theta1))
plt.plot(theta1s,costs)
plt.show()
print(pga.distance)
#調(diào)整theta
def partial_cost_theta0(x,y,theta0,theta1):
 #我們的模型是線性擬合函數(shù)時(shí):y=theta1*x + theta0,而不是sigmoid函數(shù),當(dāng)非線性時(shí)我們可以用sigmoid
 #直接多整個(gè)x series操作,省的一個(gè)一個(gè)計(jì)算,最終求sum 再平均
 h=theta1*x+theta0 
 diff=(h-y)
 partial=diff.sum()/len(diff)
 return partial
partial0=partial_cost_theta0(pga.distance,pga.accuracy,1,1)
def partial_cost_theta1(x,y,theta0,theta1):
 #我們的模型是線性擬合函數(shù):y=theta1*x + theta0,而不是sigmoid函數(shù),當(dāng)非線性時(shí)我們可以用sigmoid
 h=theta1*x+theta0
 diff=(h-y)*x
 partial=diff.sum()/len(diff)
 return partial
partial1=partial_cost_theta1(pga.distance,pga.accuracy,0,5)
print(partial0)
print(partial1)
def gradient_descent(x,y,alpha=0.1,theta0=0,theta1=0): #設(shè)置默認(rèn)參數(shù)
 #計(jì)算成本
 #調(diào)整權(quán)值
 #計(jì)算錯(cuò)誤代價(jià),判斷是否收斂或者達(dá)到最大迭代次數(shù)
 most_iterations=1000
 convergence_thres=0.000001 
 
 c=cost(x,y,theta0,theta1)
 costs=[c]
 cost_pre=c+convergence_thres+1.0
 
 counter=0
 while( (np.abs(c-cost_pre)>convergence_thres) & (counter<most_iterations) ):
 update0=alpha*partial_cost_theta0(x,y,theta0,theta1)
 update1=alpha*partial_cost_theta1(x,y,theta0,theta1)
 
 theta0-=update0
 theta1-=update1
 cost_pre=c
 c=cost(x,y,theta0,theta1)
 costs.append(c)
 counter+=1
 return {'theta0': theta0, 'theta1': theta1, "costs": costs}
print("Theta1 =", gradient_descent(pga.distance, pga.accuracy)['theta1'])
costs=gradient_descent(pga.distance,pga.accuracy,alpha=.01)['cost']
print(gradient_descent(pga.distance, pga.accuracy,alpha=.01)['theta1'])
plt.scatter(range(len(costs)),costs)
plt.show()
預(yù)覽

數(shù)據(jù)集 :

復(fù)制下面數(shù)據(jù),保存為: pga.csv

distance,accuracy
290.3,59.5
302.1,54.7
287.1,62.4
282.7,65.4
299.1,52.8
300.2,51.1
300.9,58.3
279.5,73.9
287.8,67.6
284.7,67.2
296.7,60
283.3,59.4
284,72.2
292,62.1
282.6,66.5
287.9,60.9
279.2,67.3
291.7,64.8
289.9,58.1
289.8,61.7
298.8,56.4
280.8,60.5
294.9,57.5
287.5,61.8
282.7,56
277.7,72.5
270.5,71.7
285.2,66
315.1,55.2
281.9,67.6
293.3,58.2
286,59.9
285.6,58.2
289.9,65.7
277.5,59
293.6,56.8
301.1,65.4
300.8,63.4
287.4,67.3
281.8,72.6
277.4,63.1
279.1,66.5
287.4,66.4
280.9,62.3
287.8,57.2
261.4,69.2
272.6,69.4
291.3,65.3
294.2,52.8
285.5,49
287.9,61.1
282.2,65.6
301.3,58.2
276.2,61.7
281.6,68.1
275.5,61.2
309.7,53.1
287.7,56.4
291.6,56.9
284.1,65
299.6,57.5
282.7,60
271.5,72
292.1,58.2
295,59.4
274.9,69
273.6,68.7
299.9,60.1
279.9,74
289.9,66
283.6,59.8
310.3,52.4
291.7,65.6
284.2,63.2
295,53.5
298.6,55.1
297.4,60.4
299.7,67.7
284.4,69.7
286.4,72.4
285.9,66.9
297.6,54.3
272.5,62
277,66.2
287.6,60.9
280.4,69.4
280,63.7
295.4,52.8
274.4,68.8
286.5,73.1
287.7,65.2
291.5,65.9
279,69.4
299,65.2
290.1,69.1
288.9,67.9
288.8,68.2
283.2,61
293.2,58.4
285.3,67.3
284.1,65.7
281.4,67.7
286.1,61.4
284.9,62.3
284.8,68.1
296,62
282.9,71.8
280.9,67.8
291.2,62
292.8,62.2
291,61.9
285.7,62.4
283.9,62.9
298.4,61.5
285.1,65.3
286.1,60.1
283.1,65.4
289.4,58.3
284.6,70.7
296.6,62.3
295.9,64.9
295.2,62.8
293.9,54.5
275,65.5
286.8,69.5
291.1,64.4
284.8,62.5
283.7,59.5
295.4,66.9
291.8,62.7
274.9,72.3
302.9,61.2
272.1,80.4
274.9,74.9
296.3,59.4
286.2,58.8
294.2,63.3
284.1,66.5
299.2,62.4
275.4,71
273.2,70.9
281.6,65.9
295.7,55.3
287.1,56.8
287.7,66.9
296.7,53.7
282.2,64.2
291.7,65.6
281.6,73.4
311,56.2
278.6,64.7
288,65.7
276.7,72.1
292,62
286.4,69.9
292.7,65.7
294.2,62.9
278.6,59.6
283.1,69.2
284.1,66
278.6,73.6
291.1,60.4
294.6,59.4
274.3,70.5
274,57.1
283.8,62.7
272.7,66.9
303.2,58.3
282,70.4
281.9,61
287,59.9
293.5,63.8
283.6,56.3
296.9,55.3
290.9,58.2
303,58.1
292.8,61.1
281.1,65
293,61.1
284,66.5
279.8,66.7
292.9,65.4
284,66.9
282,64.5
280.6,64
287.7,63.4
287.7,63.4
298.3,59.5
299.6,53.4
291.3,62.5
295.2,61.4
288,62.4
297.8,59.5
286,62.6
285.3,66.2
286.9,63.4
275.1,73.7

到此,關(guān)于“Java線性回歸基礎(chǔ)代碼怎么寫”的學(xué)習(xí)就結(jié)束了,希望能夠解決大家的疑惑。理論與實(shí)踐的搭配能更好的幫助大家學(xué)習(xí),快去試試吧!若想繼續(xù)學(xué)習(xí)更多相關(guān)知識(shí),請(qǐng)繼續(xù)關(guān)注億速云網(wǎng)站,小編會(huì)繼續(xù)努力為大家?guī)?lái)更多實(shí)用的文章!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI