您好,登錄后才能下訂單哦!
本篇內(nèi)容主要講解“Pytorch的使用技巧有哪些”,感興趣的朋友不妨來看看。本文介紹的方法操作簡單快捷,實用性強。下面就讓小編來帶大家學(xué)習(xí)“Pytorch的使用技巧有哪些”吧!
訓(xùn)練模型,最常看的指標(biāo)就是 Loss。我們可以根據(jù) Loss 的收斂情況,初步判斷模型訓(xùn)練的好壞。
如果,Loss 值突然上升了,那說明訓(xùn)練有問題,需要檢查數(shù)據(jù)和代碼。
如果,Loss 值趨于穩(wěn)定,那說明訓(xùn)練完畢了。
觀察 Loss 情況,最直觀的方法,就是繪制 Loss 曲線圖。
通過繪圖,我們可以很清晰的看到,左圖還有收斂空間,而右圖已經(jīng)完全收斂。
通過 Loss 曲線,我們可以分析模型訓(xùn)練的好壞,模型是否訓(xùn)練完成,起到一個很好的“監(jiān)控”作用。
繪制 Loss 曲線圖,第一步就是需要保存訓(xùn)練過程中的 Loss 值。
一個最簡單的方法是使用,sys.stdout 標(biāo)準(zhǔn)輸出重定向,簡單好用,實乃“煉丹”必備“良寶”。
import os import sys class Logger(): def __init__(self, filename="log.txt"): self.terminal = sys.stdout self.log = open(filename, "w") def write(self, message): self.terminal.write(message) self.log.write(message) def flush(self): pass sys.stdout = Logger() print("Jack Cui") print("https://cuijiahua.com") print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")
代碼很簡單,創(chuàng)建一個 log.py 文件,自己寫一個 Logger 類,并采用 sys.stdout 重定向輸出。
在 Terminal 中,不僅可以使用 print 打印結(jié)果,同時也會將結(jié)果保存到 log.txt 文件中。
運行 log.py,打印 print 內(nèi)容的同時,也將內(nèi)容寫入了 log.txt 文件中。
使用這個代碼,就可以在打印 Loss 的同時,將結(jié)果保存到指定的 txt 中,比如保存上篇文章訓(xùn)練 UNet 的 Loss。
Matplotlib 是一個 Python 的繪圖庫,簡單好用。
簡單幾行命令,就可以繪制曲線圖、散點圖、條形圖、直方圖、餅圖等等。
在深度學(xué)習(xí)中,一般就是繪制曲線圖,比如 Loss 曲線、Acc 曲線。
舉一個,簡單的例子。
使用 sys.stdout 保存的 train_loss.txt,繪制 Loss 曲線。
train_loss.txt 下載地址:點擊查看
思路非常簡單,讀取 txt 內(nèi)容,解析 txt 內(nèi)容,使用 Matplotlib 繪制曲線。
import matplotlib.pyplot as plt # Jupyter notebook 中開啟 # %matplotlib inline with open('train_loss.txt', 'r') as f: train_loss = f.readlines() train_loss = list(map(lambda x:float(x.strip()), train_loss)) x = range(len(train_loss)) y = train_loss plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5) plt.xlabel('Epoch') plt.ylabel('Loss Value') plt.legend() plt.show()
指定 x 和 y 對應(yīng)的值,就可以繪制。
是不是很簡單?
說到保存日志,那不得不提 Python 的內(nèi)置標(biāo)準(zhǔn)模塊 Logging,它主要用于輸出運行日志,可以設(shè)置輸出日志的等級、日志保存路徑、日志文件回滾等,同時,我們也可以設(shè)置日志的輸出格式。
import logging def get_logger(LEVEL, log_file = None): head = '[%(asctime)-15s] [%(levelname)s] %(message)s' if LEVEL == 'info': logging.basicConfig(level=logging.INFO, format=head) elif LEVEL == 'debug': logging.basicConfig(level=logging.DEBUG, format=head) logger = logging.getLogger() if log_file != None: fh = logging.FileHandler(log_file) logger.addHandler(fh) return logger logger = get_logger('info') logger.info('Jack Cui') logger.info('https://cuijiahua.com') logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')
只需要幾行代碼,進(jìn)行一個簡單的封裝使用。使用函數(shù) get_logger 創(chuàng)建一個級別為 info 的 logger,如果指定 log_file,則會對日志進(jìn)行保存。
logging 默認(rèn)支持的日志一共有 5 個等級:
日志級別等級 CRITICAL > ERROR > WARNING > INFO > DEBUG。
默認(rèn)的日志級別設(shè)置為 WARNING,也就是說如果不指定日志級別,只會顯示大于等于 WARNING 級別的日志。
例如:
import logging logging.debug("debug_msg") logging.info("info_msg") logging.warning("warning_msg") logging.error("error_msg") logging.critical("critical_msg")
運行結(jié)果:
WARNING:root:warning_msg ERROR:root:error_msg CRITICAL:root:critical_msg
可以看到 info 和 debug 級別的日志不會輸出,默認(rèn)的日志格式也比較簡單。
默認(rèn)的日志格式為日志級別:Logger名稱:用戶輸出消息
當(dāng)然,我們可以通過,logging.basicConfig 的 format 參數(shù),設(shè)置日志格式。
字段有很多,可謂應(yīng)有盡有,足以滿足我們定制化的需求。
上文介紹的“法寶”,并非針對深度學(xué)習(xí)“煉丹”使用的工具。
而 TensorboardX 則不同,它是專門用于深度學(xué)習(xí)“煉丹”的高級“法寶”。
早些時候,很多人更喜歡用 Tensorflow 的原因之一,就是 Tensorflow 框架有個一個很好的可視化工具 Tensorboard。
Pytorch 要想使用 Tensorboard 配置起來費勁兒不說,還有很多 Bug。
Pytorch 1.1.0 版本發(fā)布后,打破了這個局面,TensorBoard 成為了 Pytorch 的正式可用組件。
在 Pytorch 中,這個可視化工具叫做 TensorBoardX,其實就是針對 Tensorboard 的一個封裝,使得 PyTorch 用戶也能夠調(diào)用 Tensorboard。
TensorboardX 安裝也非常簡單,使用 pip 即可安裝。
pip install tensorboardX
tensorboardX 使用也很簡單,編寫如下代碼。
from tensorboardX import SummaryWriter # 創(chuàng)建 writer1 對象 # log 會保存到 runs/exp 文件夾中 writer1 = SummaryWriter('runs/exp') # 使用默認(rèn)參數(shù)創(chuàng)建 writer2 對象 # log 會保存到 runs/日期_用戶名 格式的文件夾中 writer2 = SummaryWriter() # 使用 commet 參數(shù),創(chuàng)建 writer3 對象 # log 會保存到 runs/日期_用戶名_resnet 格式的文件中 writer3 = SummaryWriter(comment='_resnet')
使用的時候,創(chuàng)建一個 SummaryWriter 對象即可,以上展示了三種初始化 SummaryWriter 的方法:
提供一個路徑,將使用該路徑來保存日志
無參數(shù),默認(rèn)將使用 runs/日期_用戶名 路徑來保存日志
提供一個 comment 參數(shù),將使用 runs/日期_用戶名+comment 路徑來保存日志
運行結(jié)果:
有了 writer 我們就可以往日志里寫入數(shù)字、圖片、甚至聲音等數(shù)據(jù)。
這個是最簡單的,使用 add_scalar 方法來記錄數(shù)字常量。
add_scalar(tag, scalar_value, global_step=None, walltime=None)
總共 4 個參數(shù)。
tag (string): 數(shù)據(jù)名稱,不同名稱的數(shù)據(jù)使用不同曲線展示
scalar_value (float): 數(shù)字常量值
global_step (int, optional): 訓(xùn)練的 step
walltime (float, optional): 記錄發(fā)生的時間,默認(rèn)為 time.time()
需要注意,這里的 scalar_value 一定是 float 類型,如果是 PyTorch scalar tensor,則需要調(diào)用 .item() 方法獲取其數(shù)值。我們一般會使用 add_scalar 方法來記錄訓(xùn)練過程的 loss、accuracy、learning rate 等數(shù)值的變化,直觀地監(jiān)控訓(xùn)練過程。
運行如下代碼:
from tensorboardX import SummaryWriter writer = SummaryWriter('runs/scalar_example') for i in range(10): writer.add_scalar('quadratic', i**2, global_step=i) writer.add_scalar('exponential', 2**i, global_step=i) writer.close()
通過 add_scalar 往日志里寫入數(shù)字,日志保存到 runs/scalar_example中,writer 用完要記得 close,否則無法保存數(shù)據(jù)。
在 cmd 中使用如下命令:
tensorboard --logdir=runs/scalar_example --port=8088
指定日志地址,使用端口號,在瀏覽器中,就可以使用如下地址,打開 Tensorboad。
http://localhost:8088/
省去了我們自己寫代碼可視化的麻煩。
使用 add_image 方法來記錄單個圖像數(shù)據(jù)。注意,該方法需要 pillow 庫的支持。
add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')
參數(shù):
tag (string):數(shù)據(jù)名稱
img_tensor (torch.Tensor / numpy.array):圖像數(shù)據(jù)
global_step (int, optional):訓(xùn)練的 step
walltime (float, optional):記錄發(fā)生的時間,默認(rèn)為 time.time()
dataformats (string, optional):圖像數(shù)據(jù)的格式,默認(rèn)為 'CHW',即 Channel x Height x Width,還可以是 'CHW'、'HWC' 或 'HW' 等
我們一般會使用 add_image 來實時觀察生成式模型的生成效果,或者可視化分割、目標(biāo)檢測的結(jié)果,幫助調(diào)試模型。
from tensorboardX import SummaryWriter from urllib.request import urlretrieve import cv2 urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg') urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg') urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg') writer = SummaryWriter('runs/image_example') for i in range(1, 4): writer.add_image('UNet_Seg', cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB), global_step=i, dataformats='HWC') writer.close()
代碼就是下載上篇文章數(shù)據(jù)集里的三張圖片,然后使用 Tensorboard 可視化處理來,使用 8088 端口開打 Tensorboard:
tensorboard --logdir=runs/image_example --port=8088
運行結(jié)果:
試想一下,一邊訓(xùn)練,一邊輸出圖片結(jié)果,是不是很酸爽呢?
Tensorboard 中常用的 Scalar 和 Image,直方圖、運行圖、嵌入向量等,可以查看官方手冊進(jìn)行學(xué)習(xí),方法都是類似的,簡單好用。
到此,相信大家對“Pytorch的使用技巧有哪些”有了更深的了解,不妨來實際操作一番吧!這里是億速云網(wǎng)站,更多相關(guān)內(nèi)容可以進(jìn)入相關(guān)頻道進(jìn)行查詢,關(guān)注我們,繼續(xù)學(xué)習(xí)!
免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。