溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務(wù)條款》

Pytorch的使用技巧有哪些

發(fā)布時間:2021-12-16 09:54:39 來源:億速云 閱讀:159 作者:iii 欄目:大數(shù)據(jù)

本篇內(nèi)容主要講解“Pytorch的使用技巧有哪些”,感興趣的朋友不妨來看看。本文介紹的方法操作簡單快捷,實用性強。下面就讓小編來帶大家學(xué)習(xí)“Pytorch的使用技巧有哪些”吧!

一、初級“法寶”,sys.stdout

訓(xùn)練模型,最常看的指標(biāo)就是 Loss。我們可以根據(jù) Loss 的收斂情況,初步判斷模型訓(xùn)練的好壞。

如果,Loss 值突然上升了,那說明訓(xùn)練有問題,需要檢查數(shù)據(jù)和代碼。

如果,Loss 值趨于穩(wěn)定,那說明訓(xùn)練完畢了。

觀察 Loss 情況,最直觀的方法,就是繪制 Loss 曲線圖。

Pytorch的使用技巧有哪些

通過繪圖,我們可以很清晰的看到,左圖還有收斂空間,而右圖已經(jīng)完全收斂。

通過 Loss 曲線,我們可以分析模型訓(xùn)練的好壞,模型是否訓(xùn)練完成,起到一個很好的“監(jiān)控”作用。

繪制 Loss 曲線圖,第一步就是需要保存訓(xùn)練過程中的 Loss 值。

一個最簡單的方法是使用,sys.stdout 標(biāo)準(zhǔn)輸出重定向,簡單好用,實乃“煉丹”必備“良寶”。

import os
import sys
class Logger():
    def __init__(self, filename="log.txt"):
        self.terminal = sys.stdout
        self.log = open(filename, "w")
 
    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
 
    def flush(self):
        pass
 
sys.stdout = Logger()
 
print("Jack Cui")
print("https://cuijiahua.com")
print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")

代碼很簡單,創(chuàng)建一個 log.py 文件,自己寫一個 Logger 類,并采用 sys.stdout 重定向輸出。

在 Terminal 中,不僅可以使用 print 打印結(jié)果,同時也會將結(jié)果保存到 log.txt 文件中。

Pytorch的使用技巧有哪些

運行 log.py,打印 print 內(nèi)容的同時,也將內(nèi)容寫入了 log.txt 文件中。

使用這個代碼,就可以在打印 Loss 的同時,將結(jié)果保存到指定的 txt 中,比如保存上篇文章訓(xùn)練 UNet 的 Loss。

Pytorch的使用技巧有哪些

二、中級“法寶”,matplotlib

Matplotlib 是一個 Python 的繪圖庫,簡單好用。

簡單幾行命令,就可以繪制曲線圖、散點圖、條形圖、直方圖、餅圖等等。

在深度學(xué)習(xí)中,一般就是繪制曲線圖,比如 Loss 曲線、Acc 曲線。

舉一個,簡單的例子。

使用 sys.stdout 保存的 train_loss.txt,繪制 Loss 曲線。

Pytorch的使用技巧有哪些

train_loss.txt 下載地址:點擊查看

思路非常簡單,讀取 txt 內(nèi)容,解析 txt 內(nèi)容,使用 Matplotlib 繪制曲線。

import matplotlib.pyplot as plt
# Jupyter notebook 中開啟
# %matplotlib inline
with open('train_loss.txt', 'r') as f:
    train_loss = f.readlines()
    train_loss = list(map(lambda x:float(x.strip()), train_loss))
x = range(len(train_loss))
y = train_loss
plt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.show()

指定 x 和 y 對應(yīng)的值,就可以繪制。

Pytorch的使用技巧有哪些

是不是很簡單?

三、中級“法寶”,Logging

說到保存日志,那不得不提 Python 的內(nèi)置標(biāo)準(zhǔn)模塊 Logging,它主要用于輸出運行日志,可以設(shè)置輸出日志的等級、日志保存路徑、日志文件回滾等,同時,我們也可以設(shè)置日志的輸出格式。

import logging

def get_logger(LEVEL, log_file = None):
    head = '[%(asctime)-15s] [%(levelname)s] %(message)s'
    if LEVEL == 'info':
        logging.basicConfig(level=logging.INFO, format=head)
    elif LEVEL == 'debug':
        logging.basicConfig(level=logging.DEBUG, format=head)
    logger = logging.getLogger()
    if log_file != None:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)
    return logger

logger = get_logger('info')

logger.info('Jack Cui')
logger.info('https://cuijiahua.com')
logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')

只需要幾行代碼,進(jìn)行一個簡單的封裝使用。使用函數(shù) get_logger 創(chuàng)建一個級別為 info 的 logger,如果指定 log_file,則會對日志進(jìn)行保存。

Pytorch的使用技巧有哪些

logging 默認(rèn)支持的日志一共有 5 個等級:

Pytorch的使用技巧有哪些

日志級別等級 CRITICAL > ERROR > WARNING > INFO > DEBUG。

默認(rèn)的日志級別設(shè)置為 WARNING,也就是說如果不指定日志級別,只會顯示大于等于 WARNING 級別的日志。

例如:

import logging
logging.debug("debug_msg")
logging.info("info_msg")
logging.warning("warning_msg")
logging.error("error_msg")
logging.critical("critical_msg")

運行結(jié)果:

WARNING:root:warning_msg
ERROR:root:error_msg
CRITICAL:root:critical_msg

可以看到 info 和 debug 級別的日志不會輸出,默認(rèn)的日志格式也比較簡單。

默認(rèn)的日志格式為日志級別:Logger名稱:用戶輸出消息

當(dāng)然,我們可以通過,logging.basicConfig 的 format 參數(shù),設(shè)置日志格式。

Pytorch的使用技巧有哪些

字段有很多,可謂應(yīng)有盡有,足以滿足我們定制化的需求。

四、高級“法寶”,TensorboardX

上文介紹的“法寶”,并非針對深度學(xué)習(xí)“煉丹”使用的工具。

而 TensorboardX 則不同,它是專門用于深度學(xué)習(xí)“煉丹”的高級“法寶”。

早些時候,很多人更喜歡用 Tensorflow 的原因之一,就是 Tensorflow 框架有個一個很好的可視化工具 Tensorboard。

Pytorch 要想使用 Tensorboard 配置起來費勁兒不說,還有很多 Bug。

Pytorch 1.1.0 版本發(fā)布后,打破了這個局面,TensorBoard 成為了 Pytorch 的正式可用組件。

在 Pytorch 中,這個可視化工具叫做 TensorBoardX,其實就是針對 Tensorboard 的一個封裝,使得 PyTorch 用戶也能夠調(diào)用 Tensorboard。

TensorboardX 安裝也非常簡單,使用 pip 即可安裝。

pip install tensorboardX

tensorboardX 使用也很簡單,編寫如下代碼。

from tensorboardX import SummaryWriter

# 創(chuàng)建 writer1 對象
# log 會保存到 runs/exp 文件夾中
writer1 = SummaryWriter('runs/exp')

# 使用默認(rèn)參數(shù)創(chuàng)建 writer2 對象
# log 會保存到 runs/日期_用戶名 格式的文件夾中
writer2 = SummaryWriter()

# 使用 commet 參數(shù),創(chuàng)建 writer3 對象
# log 會保存到 runs/日期_用戶名_resnet 格式的文件中
writer3 = SummaryWriter(comment='_resnet')

使用的時候,創(chuàng)建一個 SummaryWriter 對象即可,以上展示了三種初始化 SummaryWriter 的方法:

  • 提供一個路徑,將使用該路徑來保存日志

  • 無參數(shù),默認(rèn)將使用 runs/日期_用戶名 路徑來保存日志

  • 提供一個 comment 參數(shù),將使用 runs/日期_用戶名+comment 路徑來保存日志

運行結(jié)果:

Pytorch的使用技巧有哪些

有了 writer 我們就可以往日志里寫入數(shù)字、圖片、甚至聲音等數(shù)據(jù)。

數(shù)字 (scalar)

這個是最簡單的,使用 add_scalar 方法來記錄數(shù)字常量。

add_scalar(tag, scalar_value, global_step=None, walltime=None)

總共 4 個參數(shù)。

  • tag (string): 數(shù)據(jù)名稱,不同名稱的數(shù)據(jù)使用不同曲線展示

  • scalar_value (float): 數(shù)字常量值

  • global_step (int, optional): 訓(xùn)練的 step

  • walltime (float, optional): 記錄發(fā)生的時間,默認(rèn)為 time.time()

需要注意,這里的 scalar_value 一定是 float 類型,如果是 PyTorch scalar tensor,則需要調(diào)用 .item() 方法獲取其數(shù)值。我們一般會使用 add_scalar 方法來記錄訓(xùn)練過程的 loss、accuracy、learning rate 等數(shù)值的變化,直觀地監(jiān)控訓(xùn)練過程。

運行如下代碼:

from tensorboardX import SummaryWriter    
writer = SummaryWriter('runs/scalar_example')
for i in range(10):
    writer.add_scalar('quadratic', i**2, global_step=i)
    writer.add_scalar('exponential', 2**i, global_step=i)
writer.close()

通過 add_scalar 往日志里寫入數(shù)字,日志保存到 runs/scalar_example中,writer 用完要記得 close,否則無法保存數(shù)據(jù)。

在 cmd 中使用如下命令:

tensorboard --logdir=runs/scalar_example --port=8088

指定日志地址,使用端口號,在瀏覽器中,就可以使用如下地址,打開 Tensorboad。

http://localhost:8088/

省去了我們自己寫代碼可視化的麻煩。

Pytorch的使用技巧有哪些

圖片 (image)

使用 add_image 方法來記錄單個圖像數(shù)據(jù)。注意,該方法需要 pillow 庫的支持。

add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

參數(shù):

  • tag (string):數(shù)據(jù)名稱

  • img_tensor (torch.Tensor / numpy.array):圖像數(shù)據(jù)

  • global_step (int, optional):訓(xùn)練的 step

  • walltime (float, optional):記錄發(fā)生的時間,默認(rèn)為 time.time()

  • dataformats (string, optional):圖像數(shù)據(jù)的格式,默認(rèn)為 'CHW',即 Channel x Height x Width,還可以是 'CHW'、'HWC' 或 'HW' 等

我們一般會使用 add_image 來實時觀察生成式模型的生成效果,或者可視化分割、目標(biāo)檢測的結(jié)果,幫助調(diào)試模型。

from tensorboardX import SummaryWriter
from urllib.request import urlretrieve
import cv2

urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg')
urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg')

writer = SummaryWriter('runs/image_example')
for i in range(1, 4):
    writer.add_image('UNet_Seg',
                     cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB),
                     global_step=i,
                     dataformats='HWC')
writer.close()

代碼就是下載上篇文章數(shù)據(jù)集里的三張圖片,然后使用 Tensorboard 可視化處理來,使用 8088 端口開打 Tensorboard:

tensorboard --logdir=runs/image_example --port=8088

運行結(jié)果:

Pytorch的使用技巧有哪些

試想一下,一邊訓(xùn)練,一邊輸出圖片結(jié)果,是不是很酸爽呢?

Tensorboard 中常用的 Scalar 和 Image,直方圖、運行圖、嵌入向量等,可以查看官方手冊進(jìn)行學(xué)習(xí),方法都是類似的,簡單好用。

到此,相信大家對“Pytorch的使用技巧有哪些”有了更深的了解,不妨來實際操作一番吧!這里是億速云網(wǎng)站,更多相關(guān)內(nèi)容可以進(jìn)入相關(guān)頻道進(jìn)行查詢,關(guān)注我們,繼續(xù)學(xué)習(xí)!

向AI問一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點不代表本網(wǎng)站立場,如果涉及侵權(quán)請聯(lián)系站長郵箱:is@yisu.com進(jìn)行舉報,并提供相關(guān)證據(jù),一經(jīng)查實,將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI