This article mainly introduces "How to write basic Java linear regression code". Many readers run into questions on this topic in day-to-day work, so the editor has gone through a range of materials and put together a simple, practical walkthrough, in the hope of clearing up the confusion. Now, follow along and learn together!
# Fit a linear model to the data (the code below is Python, using scikit-learn, pandas, NumPy and matplotlib).
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# Load the dataset given in the "Dataset" section below (saved as pga.csv).
pga = pd.read_csv('pga.csv')

# Standardize both columns; with the raw distance values (roughly 260-315) the
# gradient descent updates further down diverge for the learning rates used here.
pga['distance'] = (pga['distance'] - pga['distance'].mean()) / pga['distance'].std()
pga['accuracy'] = (pga['accuracy'] - pga['accuracy'].mean()) / pga['accuracy'].std()

lr = LinearRegression()
lr.fit(pga.distance.values[:, np.newaxis], pga['accuracy'])  # another way is to pass pga[['distance']]
theta0 = lr.intercept_
theta1 = lr.coef_
print(theta0)
print(theta1)

# Cost function: the average squared error over all samples, divided by 2.
def cost(x, y, theta0, theta1):
    J = 0
    for i in range(len(x)):
        mse = (x[i] * theta1 + theta0 - y[i]) ** 2
        J += mse
    return J / (2 * len(x))

# Compute the cost for a range of theta1 values with theta0 held fixed.
theta0 = 100
theta1s = np.linspace(-3, 2, 197)
costs = []
for theta1 in theta1s:
    costs.append(cost(pga['distance'], pga['accuracy'], theta0, theta1))
plt.plot(theta1s, costs)
plt.show()
print(pga.distance)

# Partial derivative of the cost with respect to theta0.
# The model is the linear fit y = theta1*x + theta0, not a sigmoid; a sigmoid can be used for non-linear cases.
# Work on the whole x Series at once instead of element by element, then sum and average.
def partial_cost_theta0(x, y, theta0, theta1):
    h = theta1 * x + theta0
    diff = h - y
    partial = diff.sum() / len(diff)
    return partial

partial0 = partial_cost_theta0(pga.distance, pga.accuracy, 1, 1)

# Partial derivative of the cost with respect to theta1.
def partial_cost_theta1(x, y, theta0, theta1):
    h = theta1 * x + theta0
    diff = (h - y) * x
    partial = diff.sum() / len(diff)
    return partial

partial1 = partial_cost_theta1(pga.distance, pga.accuracy, 0, 5)
print(partial0)
print(partial1)

# Gradient descent: compute the cost, adjust the weights, and stop once the cost
# change falls below the convergence threshold or the iteration limit is reached.
def gradient_descent(x, y, alpha=0.1, theta0=0, theta1=0):  # default parameters
    most_iterations = 1000
    convergence_thres = 0.000001
    c = cost(x, y, theta0, theta1)
    costs = [c]
    cost_pre = c + convergence_thres + 1.0
    counter = 0
    while (np.abs(c - cost_pre) > convergence_thres) & (counter < most_iterations):
        update0 = alpha * partial_cost_theta0(x, y, theta0, theta1)
        update1 = alpha * partial_cost_theta1(x, y, theta0, theta1)
        theta0 -= update0
        theta1 -= update1
        cost_pre = c
        c = cost(x, y, theta0, theta1)
        costs.append(c)
        counter += 1
    return {'theta0': theta0, 'theta1': theta1, 'costs': costs}

print("Theta1 =", gradient_descent(pga.distance, pga.accuracy)['theta1'])
costs = gradient_descent(pga.distance, pga.accuracy, alpha=.01)['costs']  # the key is 'costs', not 'cost'
print(gradient_descent(pga.distance, pga.accuracy, alpha=.01)['theta1'])
plt.scatter(range(len(costs)), costs)
plt.show()
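For reference, the quantities that cost, partial_cost_theta0, partial_cost_theta1 and gradient_descent compute above correspond to the standard single-variable linear-regression formulas, where h_i denotes the prediction theta1*x_i + theta0 and m the number of samples:

J(\theta_0,\theta_1) = \frac{1}{2m}\sum_{i=1}^{m}\bigl(\theta_1 x_i + \theta_0 - y_i\bigr)^2

\frac{\partial J}{\partial \theta_0} = \frac{1}{m}\sum_{i=1}^{m}(h_i - y_i), \qquad
\frac{\partial J}{\partial \theta_1} = \frac{1}{m}\sum_{i=1}^{m}(h_i - y_i)\,x_i

\theta_j \leftarrow \theta_j - \alpha\,\frac{\partial J}{\partial \theta_j}

Each pass of the while loop applies the last update rule once to both parameters and stops when the cost changes by less than convergence_thres or after most_iterations passes.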
Dataset:
Copy the data below and save it as pga.csv.
distance,accuracy
290.3,59.5
302.1,54.7
287.1,62.4
282.7,65.4
299.1,52.8
300.2,51.1
300.9,58.3
279.5,73.9
287.8,67.6
284.7,67.2
296.7,60
283.3,59.4
284,72.2
292,62.1
282.6,66.5
287.9,60.9
279.2,67.3
291.7,64.8
289.9,58.1
289.8,61.7
298.8,56.4
280.8,60.5
294.9,57.5
287.5,61.8
282.7,56
277.7,72.5
270.5,71.7
285.2,66
315.1,55.2
281.9,67.6
293.3,58.2
286,59.9
285.6,58.2
289.9,65.7
277.5,59
293.6,56.8
301.1,65.4
300.8,63.4
287.4,67.3
281.8,72.6
277.4,63.1
279.1,66.5
287.4,66.4
280.9,62.3
287.8,57.2
261.4,69.2
272.6,69.4
291.3,65.3
294.2,52.8
285.5,49
287.9,61.1
282.2,65.6
301.3,58.2
276.2,61.7
281.6,68.1
275.5,61.2
309.7,53.1
287.7,56.4
291.6,56.9
284.1,65
299.6,57.5
282.7,60
271.5,72
292.1,58.2
295,59.4
274.9,69
273.6,68.7
299.9,60.1
279.9,74
289.9,66
283.6,59.8
310.3,52.4
291.7,65.6
284.2,63.2
295,53.5
298.6,55.1
297.4,60.4
299.7,67.7
284.4,69.7
286.4,72.4
285.9,66.9
297.6,54.3
272.5,62
277,66.2
287.6,60.9
280.4,69.4
280,63.7
295.4,52.8
274.4,68.8
286.5,73.1
287.7,65.2
291.5,65.9
279,69.4
299,65.2
290.1,69.1
288.9,67.9
288.8,68.2
283.2,61
293.2,58.4
285.3,67.3
284.1,65.7
281.4,67.7
286.1,61.4
284.9,62.3
284.8,68.1
296,62
282.9,71.8
280.9,67.8
291.2,62
292.8,62.2
291,61.9
285.7,62.4
283.9,62.9
298.4,61.5
285.1,65.3
286.1,60.1
283.1,65.4
289.4,58.3
284.6,70.7
296.6,62.3
295.9,64.9
295.2,62.8
293.9,54.5
275,65.5
286.8,69.5
291.1,64.4
284.8,62.5
283.7,59.5
295.4,66.9
291.8,62.7
274.9,72.3
302.9,61.2
272.1,80.4
274.9,74.9
296.3,59.4
286.2,58.8
294.2,63.3
284.1,66.5
299.2,62.4
275.4,71
273.2,70.9
281.6,65.9
295.7,55.3
287.1,56.8
287.7,66.9
296.7,53.7
282.2,64.2
291.7,65.6
281.6,73.4
311,56.2
278.6,64.7
288,65.7
276.7,72.1
292,62
286.4,69.9
292.7,65.7
294.2,62.9
278.6,59.6
283.1,69.2
284.1,66
278.6,73.6
291.1,60.4
294.6,59.4
274.3,70.5
274,57.1
283.8,62.7
272.7,66.9
303.2,58.3
282,70.4
281.9,61
287,59.9
293.5,63.8
283.6,56.3
296.9,55.3
290.9,58.2
303,58.1
292.8,61.1
281.1,65
293,61.1
284,66.5
279.8,66.7
292.9,65.4
284,66.9
282,64.5
280.6,64
287.7,63.4
287.7,63.4
298.3,59.5
299.6,53.4
291.3,62.5
295.2,61.4
288,62.4
297.8,59.5
286,62.6
285.3,66.2
286.9,63.4
275.1,73.7
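Before running the full script above, it can help to confirm that the saved file loads correctly. A minimal sanity check, assuming pga.csv sits in the current working directory, could look like this:

import pandas as pd
import matplotlib.pyplot as plt

# Load the CSV saved from the data above; expect 197 rows and 2 columns.
pga = pd.read_csv('pga.csv')
print(pga.shape)
print(pga.head())

# Quick look at the relationship: driving distance vs. driving accuracy.
plt.scatter(pga['distance'], pga['accuracy'])
plt.xlabel('distance')
plt.ylabel('accuracy')
plt.show()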
This concludes the walkthrough of "how to write basic Java linear regression code"; hopefully it has cleared up your questions. Pairing theory with practice is the best way to learn, so go give it a try! To keep learning more on related topics, keep following the 億速云 website, where the editor will continue to bring you more practical articles!