您好,登錄后才能下訂單哦!
本文小編為大家詳細介紹“PyTorch梯度下降反向傳播實例分析”,內(nèi)容詳細,步驟清晰,細節(jié)處理妥當,希望這篇“PyTorch梯度下降反向傳播實例分析”文章能幫助大家解決疑惑,下面跟著小編的思路慢慢深入,一起來學習新知識吧。
前言:
反向傳播的目的是計算成本函數(shù)C對網(wǎng)絡中任意w或b的偏導數(shù)。一旦我們有了這些偏導數(shù),我們將通過一些常數(shù) α的乘積和該數(shù)量相對于成本函數(shù)的偏導數(shù)來更新網(wǎng)絡中的權重和偏差。這是流行的梯度下降算法。而偏導數(shù)給出了最大上升的方向。因此,關于反向傳播算法,我們繼續(xù)查看下文。
我們向相反的方向邁出了一小步——最大下降的方向,也就是將我們帶到成本函數(shù)的局部最小值的方向
如題:
意思是利用這個二次模型來預測數(shù)據(jù),減小損失函數(shù)(MSE)的值。
代碼如下:
import torch import matplotlib.pyplot as plt import os os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # 數(shù)據(jù)集 x_data = [1.0,2.0,3.0] y_data = [2.0,4.0,6.0] # 權重參數(shù)初始值均為1 w = torch.tensor([1.0,1.0,1.0]) w.requires_grad = True # 需要計算梯度 # 前向傳播 def forward(x): return w[0]*(x**2)+w[1]*x+w[2] # 計算損失 def loss(x,y): y_pred = forward(x) return (y_pred-y) ** 2 # 訓練模塊 print('predict (before tranining) ',4, forward(4).item()) epoch_list = [] w_list = [] loss_list = [] for epoch in range(1000): for x,y in zip(x_data,y_data): l = loss(x,y) l.backward() # 后向傳播 print('\tgrad: ',x,y,w.grad.data) w.data = w.data - 0.01 * w.grad.data # 梯度下降 w.grad.data.zero_() # 梯度清零操作 print('progress: ',epoch,l.item()) epoch_list.append(epoch) w_list.append(w.data) loss_list.append(l.item()) print('predict (after tranining) ',4, forward(4).item()) # 繪圖 plt.plot(epoch_list,loss_list,'b') plt.xlabel('Epoch') plt.ylabel('Loss') plt.grid() plt.show()
結果如下:
predict (before tranining) 4 21.0 grad: 1.0 2.0 tensor([2., 2., 2.]) grad: 2.0 4.0 tensor([22.8800, 11.4400, 5.7200]) grad: 3.0 6.0 tensor([77.0472, 25.6824, 8.5608]) progress: 0 18.321826934814453 grad: 1.0 2.0 tensor([-1.1466, -1.1466, -1.1466]) grad: 2.0 4.0 tensor([-15.5367, -7.7683, -3.8842]) grad: 3.0 6.0 tensor([-30.4322, -10.1441, -3.3814]) progress: 1 2.858394145965576 grad: 1.0 2.0 tensor([0.3451, 0.3451, 0.3451]) grad: 2.0 4.0 tensor([2.4273, 1.2137, 0.6068]) grad: 3.0 6.0 tensor([19.4499, 6.4833, 2.1611]) progress: 2 1.1675907373428345 grad: 1.0 2.0 tensor([-0.3224, -0.3224, -0.3224]) grad: 2.0 4.0 tensor([-5.8458, -2.9229, -1.4614]) grad: 3.0 6.0 tensor([-3.8829, -1.2943, -0.4314]) progress: 3 0.04653334245085716 grad: 1.0 2.0 tensor([0.0137, 0.0137, 0.0137]) grad: 2.0 4.0 tensor([-1.9141, -0.9570, -0.4785]) grad: 3.0 6.0 tensor([6.8557, 2.2852, 0.7617]) progress: 4 0.14506366848945618 grad: 1.0 2.0 tensor([-0.1182, -0.1182, -0.1182]) grad: 2.0 4.0 tensor([-3.6644, -1.8322, -0.9161]) grad: 3.0 6.0 tensor([1.7455, 0.5818, 0.1939]) progress: 5 0.009403289295732975 grad: 1.0 2.0 tensor([-0.0333, -0.0333, -0.0333]) grad: 2.0 4.0 tensor([-2.7739, -1.3869, -0.6935]) grad: 3.0 6.0 tensor([4.0140, 1.3380, 0.4460]) progress: 6 0.04972923547029495 grad: 1.0 2.0 tensor([-0.0501, -0.0501, -0.0501]) grad: 2.0 4.0 tensor([-3.1150, -1.5575, -0.7788]) grad: 3.0 6.0 tensor([2.8534, 0.9511, 0.3170]) progress: 7 0.025129113346338272 grad: 1.0 2.0 tensor([-0.0205, -0.0205, -0.0205]) grad: 2.0 4.0 tensor([-2.8858, -1.4429, -0.7215]) grad: 3.0 6.0 tensor([3.2924, 1.0975, 0.3658]) progress: 8 0.03345605731010437 grad: 1.0 2.0 tensor([-0.0134, -0.0134, -0.0134]) grad: 2.0 4.0 tensor([-2.9247, -1.4623, -0.7312]) grad: 3.0 6.0 tensor([2.9909, 0.9970, 0.3323]) progress: 9 0.027609655633568764 grad: 1.0 2.0 tensor([0.0033, 0.0033, 0.0033]) grad: 2.0 4.0 tensor([-2.8414, -1.4207, -0.7103]) grad: 3.0 6.0 tensor([3.0377, 1.0126, 0.3375]) progress: 10 0.02848036028444767 grad: 1.0 2.0 tensor([0.0148, 0.0148, 0.0148]) grad: 2.0 4.0 tensor([-2.8174, -1.4087, -0.7043]) grad: 3.0 6.0 tensor([2.9260, 0.9753, 0.3251]) progress: 11 0.02642466314136982 grad: 1.0 2.0 tensor([0.0280, 0.0280, 0.0280]) grad: 2.0 4.0 tensor([-2.7682, -1.3841, -0.6920]) grad: 3.0 6.0 tensor([2.8915, 0.9638, 0.3213]) progress: 12 0.025804826989769936 grad: 1.0 2.0 tensor([0.0397, 0.0397, 0.0397]) grad: 2.0 4.0 tensor([-2.7330, -1.3665, -0.6832]) grad: 3.0 6.0 tensor([2.8243, 0.9414, 0.3138]) progress: 13 0.02462013065814972 grad: 1.0 2.0 tensor([0.0514, 0.0514, 0.0514]) grad: 2.0 4.0 tensor([-2.6934, -1.3467, -0.6734]) grad: 3.0 6.0 tensor([2.7756, 0.9252, 0.3084]) progress: 14 0.023777369409799576 grad: 1.0 2.0 tensor([0.0624, 0.0624, 0.0624]) grad: 2.0 4.0 tensor([-2.6580, -1.3290, -0.6645]) grad: 3.0 6.0 tensor([2.7213, 0.9071, 0.3024]) progress: 15 0.0228563379496336 grad: 1.0 2.0 tensor([0.0731, 0.0731, 0.0731]) grad: 2.0 4.0 tensor([-2.6227, -1.3113, -0.6557]) grad: 3.0 6.0 tensor([2.6725, 0.8908, 0.2969]) progress: 16 0.022044027224183083 grad: 1.0 2.0 tensor([0.0833, 0.0833, 0.0833]) grad: 2.0 4.0 tensor([-2.5893, -1.2946, -0.6473]) grad: 3.0 6.0 tensor([2.6240, 0.8747, 0.2916]) progress: 17 0.02125072106719017 grad: 1.0 2.0 tensor([0.0931, 0.0931, 0.0931]) grad: 2.0 4.0 tensor([-2.5568, -1.2784, -0.6392]) grad: 3.0 6.0 tensor([2.5780, 0.8593, 0.2864]) progress: 18 0.020513182505965233 grad: 1.0 2.0 tensor([0.1025, 0.1025, 0.1025]) grad: 2.0 4.0 tensor([-2.5258, -1.2629, -0.6314]) grad: 3.0 6.0 tensor([2.5335, 0.8445, 0.2815]) progress: 19 0.019810274243354797 grad: 1.0 2.0 tensor([0.1116, 0.1116, 0.1116]) grad: 2.0 4.0 tensor([-2.4958, -1.2479, -0.6239]) grad: 3.0 6.0 tensor([2.4908, 0.8303, 0.2768]) progress: 20 0.019148115068674088 grad: 1.0 2.0 tensor([0.1203, 0.1203, 0.1203]) grad: 2.0 4.0 tensor([-2.4669, -1.2335, -0.6167]) grad: 3.0 6.0 tensor([2.4496, 0.8165, 0.2722]) progress: 21 0.018520694226026535 grad: 1.0 2.0 tensor([0.1286, 0.1286, 0.1286]) grad: 2.0 4.0 tensor([-2.4392, -1.2196, -0.6098]) grad: 3.0 6.0 tensor([2.4101, 0.8034, 0.2678]) progress: 22 0.017927465960383415 grad: 1.0 2.0 tensor([0.1367, 0.1367, 0.1367]) grad: 2.0 4.0 tensor([-2.4124, -1.2062, -0.6031]) grad: 3.0 6.0 tensor([2.3720, 0.7907, 0.2636]) progress: 23 0.01736525259912014 grad: 1.0 2.0 tensor([0.1444, 0.1444, 0.1444]) grad: 2.0 4.0 tensor([-2.3867, -1.1933, -0.5967]) grad: 3.0 6.0 tensor([2.3354, 0.7785, 0.2595]) progress: 24 0.016833148896694183 grad: 1.0 2.0 tensor([0.1518, 0.1518, 0.1518]) grad: 2.0 4.0 tensor([-2.3619, -1.1810, -0.5905]) grad: 3.0 6.0 tensor([2.3001, 0.7667, 0.2556]) progress: 25 0.01632905937731266 grad: 1.0 2.0 tensor([0.1589, 0.1589, 0.1589]) grad: 2.0 4.0 tensor([-2.3380, -1.1690, -0.5845]) grad: 3.0 6.0 tensor([2.2662, 0.7554, 0.2518]) progress: 26 0.01585075818002224 grad: 1.0 2.0 tensor([0.1657, 0.1657, 0.1657]) grad: 2.0 4.0 tensor([-2.3151, -1.1575, -0.5788]) grad: 3.0 6.0 tensor([2.2336, 0.7445, 0.2482]) progress: 27 0.015397666022181511 grad: 1.0 2.0 tensor([0.1723, 0.1723, 0.1723]) grad: 2.0 4.0 tensor([-2.2929, -1.1465, -0.5732]) grad: 3.0 6.0 tensor([2.2022, 0.7341, 0.2447]) progress: 28 0.014967591501772404 grad: 1.0 2.0 tensor([0.1786, 0.1786, 0.1786]) grad: 2.0 4.0 tensor([-2.2716, -1.1358, -0.5679]) grad: 3.0 6.0 tensor([2.1719, 0.7240, 0.2413]) progress: 29 0.014559715054929256 grad: 1.0 2.0 tensor([0.1846, 0.1846, 0.1846]) grad: 2.0 4.0 tensor([-2.2511, -1.1255, -0.5628]) grad: 3.0 6.0 tensor([2.1429, 0.7143, 0.2381]) progress: 30 0.014172340743243694 grad: 1.0 2.0 tensor([0.1904, 0.1904, 0.1904]) grad: 2.0 4.0 tensor([-2.2313, -1.1157, -0.5578]) grad: 3.0 6.0 tensor([2.1149, 0.7050, 0.2350]) progress: 31 0.013804304413497448 grad: 1.0 2.0 tensor([0.1960, 0.1960, 0.1960]) grad: 2.0 4.0 tensor([-2.2123, -1.1061, -0.5531]) grad: 3.0 6.0 tensor([2.0879, 0.6960, 0.2320]) progress: 32 0.013455045409500599 grad: 1.0 2.0 tensor([0.2014, 0.2014, 0.2014]) grad: 2.0 4.0 tensor([-2.1939, -1.0970, -0.5485]) grad: 3.0 6.0 tensor([2.0620, 0.6873, 0.2291]) progress: 33 0.013122711330652237 grad: 1.0 2.0 tensor([0.2065, 0.2065, 0.2065]) grad: 2.0 4.0 tensor([-2.1763, -1.0881, -0.5441]) grad: 3.0 6.0 tensor([2.0370, 0.6790, 0.2263]) progress: 34 0.01280694268643856 grad: 1.0 2.0 tensor([0.2114, 0.2114, 0.2114]) grad: 2.0 4.0 tensor([-2.1592, -1.0796, -0.5398]) grad: 3.0 6.0 tensor([2.0130, 0.6710, 0.2237]) progress: 35 0.012506747618317604 grad: 1.0 2.0 tensor([0.2162, 0.2162, 0.2162]) grad: 2.0 4.0 tensor([-2.1428, -1.0714, -0.5357]) grad: 3.0 6.0 tensor([1.9899, 0.6633, 0.2211]) progress: 36 0.012220758944749832 grad: 1.0 2.0 tensor([0.2207, 0.2207, 0.2207]) grad: 2.0 4.0 tensor([-2.1270, -1.0635, -0.5317]) grad: 3.0 6.0 tensor([1.9676, 0.6559, 0.2186]) progress: 37 0.01194891706109047 grad: 1.0 2.0 tensor([0.2251, 0.2251, 0.2251]) grad: 2.0 4.0 tensor([-2.1118, -1.0559, -0.5279]) grad: 3.0 6.0 tensor([1.9462, 0.6487, 0.2162]) progress: 38 0.011689926497638226 grad: 1.0 2.0 tensor([0.2292, 0.2292, 0.2292]) grad: 2.0 4.0 tensor([-2.0971, -1.0485, -0.5243]) grad: 3.0 6.0 tensor([1.9255, 0.6418, 0.2139]) progress: 39 0.01144315768033266 grad: 1.0 2.0 tensor([0.2333, 0.2333, 0.2333]) grad: 2.0 4.0 tensor([-2.0829, -1.0414, -0.5207]) grad: 3.0 6.0 tensor([1.9057, 0.6352, 0.2117]) progress: 40 0.011208509095013142 grad: 1.0 2.0 tensor([0.2371, 0.2371, 0.2371]) grad: 2.0 4.0 tensor([-2.0693, -1.0346, -0.5173]) grad: 3.0 6.0 tensor([1.8865, 0.6288, 0.2096]) progress: 41 0.0109840864315629 grad: 1.0 2.0 tensor([0.2408, 0.2408, 0.2408]) grad: 2.0 4.0 tensor([-2.0561, -1.0280, -0.5140]) grad: 3.0 6.0 tensor([1.8681, 0.6227, 0.2076]) progress: 42 0.010770938359200954 grad: 1.0 2.0 tensor([0.2444, 0.2444, 0.2444]) grad: 2.0 4.0 tensor([-2.0434, -1.0217, -0.5108]) grad: 3.0 6.0 tensor([1.8503, 0.6168, 0.2056]) progress: 43 0.010566935874521732 grad: 1.0 2.0 tensor([0.2478, 0.2478, 0.2478]) grad: 2.0 4.0 tensor([-2.0312, -1.0156, -0.5078]) grad: 3.0 6.0 tensor([1.8332, 0.6111, 0.2037]) progress: 44 0.010372749529778957 grad: 1.0 2.0 tensor([0.2510, 0.2510, 0.2510]) grad: 2.0 4.0 tensor([-2.0194, -1.0097, -0.5048]) grad: 3.0 6.0 tensor([1.8168, 0.6056, 0.2019]) progress: 45 0.010187389329075813 grad: 1.0 2.0 tensor([0.2542, 0.2542, 0.2542]) grad: 2.0 4.0 tensor([-2.0080, -1.0040, -0.5020]) grad: 3.0 6.0 tensor([1.8009, 0.6003, 0.2001]) progress: 46 0.010010283440351486 grad: 1.0 2.0 tensor([0.2572, 0.2572, 0.2572]) grad: 2.0 4.0 tensor([-1.9970, -0.9985, -0.4992]) grad: 3.0 6.0 tensor([1.7856, 0.5952, 0.1984]) progress: 47 0.00984097272157669 grad: 1.0 2.0 tensor([0.2600, 0.2600, 0.2600]) grad: 2.0 4.0 tensor([-1.9864, -0.9932, -0.4966]) grad: 3.0 6.0 tensor([1.7709, 0.5903, 0.1968]) progress: 48 0.009679674170911312 grad: 1.0 2.0 tensor([0.2628, 0.2628, 0.2628]) grad: 2.0 4.0 tensor([-1.9762, -0.9881, -0.4940]) grad: 3.0 6.0 tensor([1.7568, 0.5856, 0.1952]) progress: 49 0.009525291621685028 grad: 1.0 2.0 tensor([0.2655, 0.2655, 0.2655]) grad: 2.0 4.0 tensor([-1.9663, -0.9832, -0.4916]) grad: 3.0 6.0 tensor([1.7431, 0.5810, 0.1937]) progress: 50 0.00937769003212452 grad: 1.0 2.0 tensor([0.2680, 0.2680, 0.2680]) grad: 2.0 4.0 tensor([-1.9568, -0.9784, -0.4892]) grad: 3.0 6.0 tensor([1.7299, 0.5766, 0.1922]) progress: 51 0.009236648678779602 grad: 1.0 2.0 tensor([0.2704, 0.2704, 0.2704]) grad: 2.0 4.0 tensor([-1.9476, -0.9738, -0.4869]) grad: 3.0 6.0 tensor([1.7172, 0.5724, 0.1908]) progress: 52 0.00910158734768629 grad: 1.0 2.0 tensor([0.2728, 0.2728, 0.2728]) grad: 2.0 4.0 tensor([-1.9387, -0.9694, -0.4847]) grad: 3.0 6.0 tensor([1.7050, 0.5683, 0.1894]) progress: 53 0.00897257961332798 grad: 1.0 2.0 tensor([0.2750, 0.2750, 0.2750]) grad: 2.0 4.0 tensor([-1.9301, -0.9651, -0.4825]) grad: 3.0 6.0 tensor([1.6932, 0.5644, 0.1881]) progress: 54 0.008848887868225574 grad: 1.0 2.0 tensor([0.2771, 0.2771, 0.2771]) grad: 2.0 4.0 tensor([-1.9219, -0.9609, -0.4805]) grad: 3.0 6.0 tensor([1.6819, 0.5606, 0.1869]) progress: 55 0.008730598725378513 grad: 1.0 2.0 tensor([0.2792, 0.2792, 0.2792]) grad: 2.0 4.0 tensor([-1.9139, -0.9569, -0.4785]) grad: 3.0 6.0 tensor([1.6709, 0.5570, 0.1857]) progress: 56 0.00861735362559557 grad: 1.0 2.0 tensor([0.2811, 0.2811, 0.2811]) grad: 2.0 4.0 tensor([-1.9062, -0.9531, -0.4765]) grad: 3.0 6.0 tensor([1.6604, 0.5535, 0.1845]) progress: 57 0.008508718572556973 grad: 1.0 2.0 tensor([0.2830, 0.2830, 0.2830]) grad: 2.0 4.0 tensor([-1.8987, -0.9493, -0.4747]) grad: 3.0 6.0 tensor([1.6502, 0.5501, 0.1834]) progress: 58 0.008404706604778767 grad: 1.0 2.0 tensor([0.2848, 0.2848, 0.2848]) grad: 2.0 4.0 tensor([-1.8915, -0.9457, -0.4729]) grad: 3.0 6.0 tensor([1.6404, 0.5468, 0.1823]) progress: 59 0.008305158466100693 grad: 1.0 2.0 tensor([0.2865, 0.2865, 0.2865]) grad: 2.0 4.0 tensor([-1.8845, -0.9423, -0.4711]) grad: 3.0 6.0 tensor([1.6309, 0.5436, 0.1812]) progress: 60 0.00820931326597929 grad: 1.0 2.0 tensor([0.2882, 0.2882, 0.2882]) grad: 2.0 4.0 tensor([-1.8778, -0.9389, -0.4694]) grad: 3.0 6.0 tensor([1.6218, 0.5406, 0.1802]) progress: 61 0.008117804303765297 grad: 1.0 2.0 tensor([0.2898, 0.2898, 0.2898]) grad: 2.0 4.0 tensor([-1.8713, -0.9356, -0.4678]) grad: 3.0 6.0 tensor([1.6130, 0.5377, 0.1792]) progress: 62 0.008029798977077007 grad: 1.0 2.0 tensor([0.2913, 0.2913, 0.2913]) grad: 2.0 4.0 tensor([-1.8650, -0.9325, -0.4662]) grad: 3.0 6.0 tensor([1.6045, 0.5348, 0.1783]) progress: 63 0.007945418357849121 grad: 1.0 2.0 tensor([0.2927, 0.2927, 0.2927]) grad: 2.0 4.0 tensor([-1.8589, -0.9294, -0.4647]) grad: 3.0 6.0 tensor([1.5962, 0.5321, 0.1774]) progress: 64 0.007864190265536308 grad: 1.0 2.0 tensor([0.2941, 0.2941, 0.2941]) grad: 2.0 4.0 tensor([-1.8530, -0.9265, -0.4632]) grad: 3.0 6.0 tensor([1.5884, 0.5295, 0.1765]) progress: 65 0.007786744274199009 grad: 1.0 2.0 tensor([0.2954, 0.2954, 0.2954]) grad: 2.0 4.0 tensor([-1.8473, -0.9236, -0.4618]) grad: 3.0 6.0 tensor([1.5807, 0.5269, 0.1756]) progress: 66 0.007711691781878471 grad: 1.0 2.0 tensor([0.2967, 0.2967, 0.2967]) grad: 2.0 4.0 tensor([-1.8417, -0.9209, -0.4604]) grad: 3.0 6.0 tensor([1.5733, 0.5244, 0.1748]) progress: 67 0.007640169933438301 grad: 1.0 2.0 tensor([0.2979, 0.2979, 0.2979]) grad: 2.0 4.0 tensor([-1.8364, -0.9182, -0.4591]) grad: 3.0 6.0 tensor([1.5662, 0.5221, 0.1740]) progress: 68 0.007570972666144371 grad: 1.0 2.0 tensor([0.2991, 0.2991, 0.2991]) grad: 2.0 4.0 tensor([-1.8312, -0.9156, -0.4578]) grad: 3.0 6.0 tensor([1.5593, 0.5198, 0.1733]) progress: 69 0.007504733745008707 grad: 1.0 2.0 tensor([0.3002, 0.3002, 0.3002]) grad: 2.0 4.0 tensor([-1.8262, -0.9131, -0.4566]) grad: 3.0 6.0 tensor([1.5527, 0.5176, 0.1725]) progress: 70 0.007440924644470215 grad: 1.0 2.0 tensor([0.3012, 0.3012, 0.3012]) grad: 2.0 4.0 tensor([-1.8214, -0.9107, -0.4553]) grad: 3.0 6.0 tensor([1.5463, 0.5154, 0.1718]) progress: 71 0.007379599846899509 grad: 1.0 2.0 tensor([0.3022, 0.3022, 0.3022]) grad: 2.0 4.0 tensor([-1.8167, -0.9083, -0.4542]) grad: 3.0 6.0 tensor([1.5401, 0.5134, 0.1711]) progress: 72 0.007320486940443516 grad: 1.0 2.0 tensor([0.3032, 0.3032, 0.3032]) grad: 2.0 4.0 tensor([-1.8121, -0.9060, -0.4530]) grad: 3.0 6.0 tensor([1.5341, 0.5114, 0.1705]) progress: 73 0.007263725157827139 grad: 1.0 2.0 tensor([0.3041, 0.3041, 0.3041]) grad: 2.0 4.0 tensor([-1.8077, -0.9038, -0.4519]) grad: 3.0 6.0 tensor([1.5283, 0.5094, 0.1698]) progress: 74 0.007209045812487602 grad: 1.0 2.0 tensor([0.3050, 0.3050, 0.3050]) grad: 2.0 4.0 tensor([-1.8034, -0.9017, -0.4508]) grad: 3.0 6.0 tensor([1.5227, 0.5076, 0.1692]) progress: 75 0.007156429346650839 grad: 1.0 2.0 tensor([0.3058, 0.3058, 0.3058]) grad: 2.0 4.0 tensor([-1.7992, -0.8996, -0.4498]) grad: 3.0 6.0 tensor([1.5173, 0.5058, 0.1686]) progress: 76 0.007105532102286816 grad: 1.0 2.0 tensor([0.3066, 0.3066, 0.3066]) grad: 2.0 4.0 tensor([-1.7952, -0.8976, -0.4488]) grad: 3.0 6.0 tensor([1.5121, 0.5040, 0.1680]) progress: 77 0.00705681974068284 grad: 1.0 2.0 tensor([0.3073, 0.3073, 0.3073]) grad: 2.0 4.0 tensor([-1.7913, -0.8956, -0.4478]) grad: 3.0 6.0 tensor([1.5070, 0.5023, 0.1674]) progress: 78 0.007009552326053381 grad: 1.0 2.0 tensor([0.3081, 0.3081, 0.3081]) grad: 2.0 4.0 tensor([-1.7875, -0.8937, -0.4469]) grad: 3.0 6.0 tensor([1.5021, 0.5007, 0.1669]) progress: 79 0.006964194122701883 grad: 1.0 2.0 tensor([0.3087, 0.3087, 0.3087]) grad: 2.0 4.0 tensor([-1.7838, -0.8919, -0.4459]) grad: 3.0 6.0 tensor([1.4974, 0.4991, 0.1664]) progress: 80 0.006920332089066505 grad: 1.0 2.0 tensor([0.3094, 0.3094, 0.3094]) grad: 2.0 4.0 tensor([-1.7802, -0.8901, -0.4450]) grad: 3.0 6.0 tensor([1.4928, 0.4976, 0.1659]) progress: 81 0.006878111511468887 grad: 1.0 2.0 tensor([0.3100, 0.3100, 0.3100]) grad: 2.0 4.0 tensor([-1.7767, -0.8883, -0.4442]) grad: 3.0 6.0 tensor([1.4884, 0.4961, 0.1654]) progress: 82 0.006837360095232725 grad: 1.0 2.0 tensor([0.3106, 0.3106, 0.3106]) grad: 2.0 4.0 tensor([-1.7733, -0.8867, -0.4433]) grad: 3.0 6.0 tensor([1.4841, 0.4947, 0.1649]) progress: 83 0.006797831039875746 grad: 1.0 2.0 tensor([0.3111, 0.3111, 0.3111]) grad: 2.0 4.0 tensor([-1.7700, -0.8850, -0.4425]) grad: 3.0 6.0 tensor([1.4800, 0.4933, 0.1644]) progress: 84 0.006760062649846077 grad: 1.0 2.0 tensor([0.3117, 0.3117, 0.3117]) grad: 2.0 4.0 tensor([-1.7668, -0.8834, -0.4417]) grad: 3.0 6.0 tensor([1.4759, 0.4920, 0.1640]) progress: 85 0.006723103579133749 grad: 1.0 2.0 tensor([0.3122, 0.3122, 0.3122]) grad: 2.0 4.0 tensor([-1.7637, -0.8818, -0.4409]) grad: 3.0 6.0 tensor([1.4720, 0.4907, 0.1636]) progress: 86 0.00668772729113698 grad: 1.0 2.0 tensor([0.3127, 0.3127, 0.3127]) grad: 2.0 4.0 tensor([-1.7607, -0.8803, -0.4402]) grad: 3.0 6.0 tensor([1.4682, 0.4894, 0.1631]) progress: 87 0.006653300020843744 grad: 1.0 2.0 tensor([0.3131, 0.3131, 0.3131]) grad: 2.0 4.0 tensor([-1.7577, -0.8789, -0.4394]) grad: 3.0 6.0 tensor([1.4646, 0.4882, 0.1627]) progress: 88 0.0066203586757183075 grad: 1.0 2.0 tensor([0.3135, 0.3135, 0.3135]) grad: 2.0 4.0 tensor([-1.7548, -0.8774, -0.4387]) grad: 3.0 6.0 tensor([1.4610, 0.4870, 0.1623]) progress: 89 0.0065881176851689816 grad: 1.0 2.0 tensor([0.3139, 0.3139, 0.3139]) grad: 2.0 4.0 tensor([-1.7520, -0.8760, -0.4380]) grad: 3.0 6.0 tensor([1.4576, 0.4859, 0.1620]) progress: 90 0.0065572685562074184 grad: 1.0 2.0 tensor([0.3143, 0.3143, 0.3143]) grad: 2.0 4.0 tensor([-1.7493, -0.8747, -0.4373]) grad: 3.0 6.0 tensor([1.4542, 0.4847, 0.1616]) progress: 91 0.0065271081402897835 grad: 1.0 2.0 tensor([0.3147, 0.3147, 0.3147]) grad: 2.0 4.0 tensor([-1.7466, -0.8733, -0.4367]) grad: 3.0 6.0 tensor([1.4510, 0.4837, 0.1612]) progress: 92 0.00649801641702652 grad: 1.0 2.0 tensor([0.3150, 0.3150, 0.3150]) grad: 2.0 4.0 tensor([-1.7441, -0.8720, -0.4360]) grad: 3.0 6.0 tensor([1.4478, 0.4826, 0.1609]) progress: 93 0.0064699104987084866 grad: 1.0 2.0 tensor([0.3153, 0.3153, 0.3153]) grad: 2.0 4.0 tensor([-1.7415, -0.8708, -0.4354]) grad: 3.0 6.0 tensor([1.4448, 0.4816, 0.1605]) progress: 94 0.006442630663514137 grad: 1.0 2.0 tensor([0.3156, 0.3156, 0.3156]) grad: 2.0 4.0 tensor([-1.7391, -0.8695, -0.4348]) grad: 3.0 6.0 tensor([1.4418, 0.4806, 0.1602]) progress: 95 0.006416172254830599 grad: 1.0 2.0 tensor([0.3159, 0.3159, 0.3159]) grad: 2.0 4.0 tensor([-1.7366, -0.8683, -0.4342]) grad: 3.0 6.0 tensor([1.4389, 0.4796, 0.1599]) progress: 96 0.006390606984496117 grad: 1.0 2.0 tensor([0.3161, 0.3161, 0.3161]) grad: 2.0 4.0 tensor([-1.7343, -0.8671, -0.4336]) grad: 3.0 6.0 tensor([1.4361, 0.4787, 0.1596]) progress: 97 0.0063657015562057495 grad: 1.0 2.0 tensor([0.3164, 0.3164, 0.3164]) grad: 2.0 4.0 tensor([-1.7320, -0.8660, -0.4330]) grad: 3.0 6.0 tensor([1.4334, 0.4778, 0.1593]) progress: 98 0.0063416799530386925 grad: 1.0 2.0 tensor([0.3166, 0.3166, 0.3166]) grad: 2.0 4.0 tensor([-1.7297, -0.8649, -0.4324]) grad: 3.0 6.0 tensor([1.4308, 0.4769, 0.1590]) progress: 99 0.00631808303296566 predict (after tranining) 4 8.544171333312988
損失值隨著迭代次數(shù)的增加呈遞減趨勢,如下圖所示:
可以看出:x=4時的預測值約為8.5,與真實值8有所差距,可通過提高迭代次數(shù)或者調(diào)整學習率、初始參數(shù)等方法來減小差距。
讀到這里,這篇“PyTorch梯度下降反向傳播實例分析”文章已經(jīng)介紹完畢,想要掌握這篇文章的知識點還需要大家自己動手實踐使用過才能領會,如果想了解更多相關內(nèi)容的文章,歡迎關注億速云行業(yè)資訊頻道。
免責聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權請聯(lián)系站長郵箱:is@yisu.com進行舉報,并提供相關證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權內(nèi)容。