溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊(cè)×
其他方式登錄
點(diǎn)擊 登錄注冊(cè) 即表示同意《億速云用戶服務(wù)條款》

什么是Config和Trainer

發(fā)布時(shí)間:2021-10-12 10:37:47 來(lái)源:億速云 閱讀:194 作者:iii 欄目:編程語(yǔ)言

這篇文章主要講解了“什么是Config和Trainer”,文中的講解內(nèi)容簡(jiǎn)單清晰,易于學(xué)習(xí)與理解,下面請(qǐng)大家跟著小編的思路慢慢深入,一起來(lái)研究和學(xué)習(xí)“什么是Config和Trainer”吧!

代碼結(jié)構(gòu)概覽

核心部分

  • configs:儲(chǔ)存各種網(wǎng)絡(luò)的yaml配置文件

  • datasets:存放數(shù)據(jù)集的地方

  • detectron2:運(yùn)行代碼的核心組件

  • tools:提供了運(yùn)行代碼的入口以及一切可視化的代碼文件。

Tutorial部分

  • demo:顯而易見(jiàn)就是demo

  • docs: 同樣顯而易見(jiàn)。。

  • tests:提供了一些測(cè)試代碼

  • projects:提供了真實(shí)的項(xiàng)目代碼示例,之后自己的代碼結(jié)構(gòu)可參照這個(gè)結(jié)構(gòu)寫。

代碼邏輯分析

超參數(shù)配置

進(jìn)入tools/train_net.pymain函數(shù),第一行cfg = setup(args)是配置參數(shù)。Detectron2中的參數(shù)配置使用了yacs這個(gè)庫(kù),這個(gè)庫(kù)能夠很好地重用和拼接超參數(shù)文件配置。

我們先看一下detrctron2/config/的文件結(jié)構(gòu):

  • compat.py: 應(yīng)該是對(duì)之前的Detectron庫(kù)的兼容吧,可忽略。三門峽婦科醫(yī)院http://www.0398hfyy.com/

  • config.py: 定義了一個(gè)CfgNode類,這個(gè)類繼承自fvcore庫(kù)(fb寫的一個(gè)共公共庫(kù),提供一些共享的函數(shù),方便各種不同項(xiàng)目使用)中定義的CfgNode,總之就是不斷繼承。。。繼承關(guān)系是這樣的:
    detrctron2.config.CfgNode->fcvore.common.config.CfgNode->yacs.config.CfgNode->dict
    另外該文件還提供了get_cfg()方法,該方法會(huì)返回一個(gè)含有默認(rèn)配置的CfgNode,而這些默認(rèn)的配置值在下面的default.py中定義了,之所以這樣做是因?yàn)橐渲玫哪J(rèn)值太多了,所以為了文檔清晰才寫到了一個(gè)新的文件中去,不過(guò),yacs庫(kù)的作者也建議這樣做。

  • default.py: 如上面所說(shuō),該文件定義了各種參數(shù)的默認(rèn)值。

了解配置函數(shù)的方法后我們?cè)倩氐?code>tools/train_net.py,我們一行一行的來(lái)理解。

  • tools/train_net.py

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
...

def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg() 
    cfg.merge_from_file(args.config_file) 
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
  • cfg = get_cfg(): 獲取已經(jīng)配置好默認(rèn)參數(shù)的cfg

  • cfg.merge_from_file(args.config_file):config_file是指定的yaml配置文件,通過(guò)merge_from_file這個(gè)函數(shù)會(huì)將yaml文件中指定的超參數(shù)對(duì)默認(rèn)值進(jìn)行覆蓋。

  • cfg.merge_from_list(args.opts):merge_from_list作用同上面的類似,只不過(guò)是通過(guò)命令行的方式覆蓋。
    例如

opts = ["SYSTEM.NUM_GPUS", 8, "TRAIN.SCALES", "(1, 2, 3, 4)"]
cfg.merge_from_list(opts)
print("cfg\n",cfg)

那么最后會(huì)有

cfg
... (一些默認(rèn)值超參數(shù))
SYSTEM:
	NUM_GPUS: 8
TRAIN:
	SCALES: (1,2,3,4)
  • cfg.freeze(): freeze函數(shù)的作用是將超參數(shù)值凍結(jié),避免被程序不小心修改。

  • default_setup(cfg, args):default_setupdetectron2/engine/default.py中提供的一個(gè)默認(rèn)配置函數(shù),具體是怎么配置的這里不詳細(xì)說(shuō)明了。不過(guò)需要知道的值這個(gè)文件中還提供了很多其他的配置函數(shù),例如還提供了兩個(gè)類:DefaultPredictorDefaultTrainer

Trainer

既然上面提到了DefaultTrainer,那么我們就從這個(gè)類入手了解一下detectron2.engine,其代碼結(jié)構(gòu)如下:

  • train_loop.py: 這個(gè)函數(shù)主要作用是提供了三個(gè)重要的類:

    • register_hooks:這個(gè)很好理解,就是將用戶定義的一些hooks進(jìn)行注冊(cè),說(shuō)大白話就是把若干個(gè)Hook放在一個(gè)list里面去。之后只需要遍歷這個(gè)list依次執(zhí)行就可以了。

    • 第二類其實(shí)就是上面提到的遍歷hook list并執(zhí)行hook,不過(guò)這個(gè)遍歷有四種,分別是before_train,after_train,before_step,after_step。還有一個(gè)就是run_step,這個(gè)函數(shù)其實(shí)就是平常我們?cè)诰帉懹?xùn)練過(guò)程的代碼,例如讀數(shù)據(jù),訓(xùn)練模型,獲取損失值,求導(dǎo)數(shù),反向梯度更新等,只不過(guò)在這個(gè)類里面沒(méi)有定義。

    • 第三類就是train函數(shù),它有兩個(gè)參數(shù),分別是開始的迭代數(shù)和最大的迭代數(shù)。之后就是重復(fù)依次執(zhí)行第二類中的函數(shù)指定迭代次數(shù)。

    • HookBase: 這是一個(gè)Hook的基類,用于指定在訓(xùn)練前后或者每一個(gè)step前后需要做什么事情,所以根據(jù)特定的需求需要對(duì)如下四種方法做不同的定義:before_train,after_train,before_step,after_step。以before_step

    • TrainerBase: 該類中定義的函數(shù)可以歸納成三種:

    • SimpleTrainer:其實(shí)就是繼承自TrainerBase,然后定義了run_step等方法。我們后面也可以繼承這個(gè)類做進(jìn)一步的自定義。

  • defaults.py: 上面已介紹,提供了兩個(gè)類:DefaultPredictorDefaultTrainer,這個(gè)DefaultTrainer就繼承自SimpleTrainer,所以存在如下繼承關(guān)系:
    detectron2.engine.default.DefaultTrainer->detectron2.engine.train_loop.SimpleTrainer->detectron2.engine.train_loop.TrainerBase

  • hooks.py:定義了很多繼承自train_loop.HookBase的Hook。

  • launch.py: 前面提到過(guò),可以理解成代碼啟動(dòng)器,可以根據(jù)命令決定是否采用分布式訓(xùn)練(或者單機(jī)多卡)或者單機(jī)單卡訓(xùn)練。

好了,我們繼續(xù)回到tools/train_net.py的main函數(shù),代碼如下所示。

def main(args):
    cfg = setup(args)

    if args.eval_only:
		...
    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    if cfg.TEST.AUG.ENABLED:
        trainer.register_hooks(
            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
        )
    return trainer.train()

可以看到下面定義了一個(gè)Trainer,它繼承自detectron2.engine.default.DefaultTrainer,這個(gè)父類會(huì)自動(dòng)解析cfg。之后只需要調(diào)用trainer.train()就可以開始訓(xùn)練了。

感謝各位的閱讀,以上就是“什么是Config和Trainer”的內(nèi)容了,經(jīng)過(guò)本文的學(xué)習(xí)后,相信大家對(duì)什么是Config和Trainer這一問(wèn)題有了更深刻的體會(huì),具體使用情況還需要大家實(shí)踐驗(yàn)證。這里是億速云,小編將為大家推送更多相關(guān)知識(shí)點(diǎn)的文章,歡迎關(guān)注!

向AI問(wèn)一下細(xì)節(jié)

免責(zé)聲明:本站發(fā)布的內(nèi)容(圖片、視頻和文字)以原創(chuàng)、轉(zhuǎn)載和分享為主,文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如果涉及侵權(quán)請(qǐng)聯(lián)系站長(zhǎng)郵箱:is@yisu.com進(jìn)行舉報(bào),并提供相關(guān)證據(jù),一經(jīng)查實(shí),將立刻刪除涉嫌侵權(quán)內(nèi)容。

AI