溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

關于Pytorch中模型的保存與遷移問題有哪些

發(fā)布時間:2021-10-18 17:07:44 來源:億速云 閱讀:115 作者:iii 欄目:開發(fā)技術

本篇內容介紹了“關于Pytorch中模型的保存與遷移問題有哪些”的有關知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領大家學習一下如何處理這些情況吧!希望大家仔細閱讀,能夠學有所成!

目錄
  • 1 引言

  • 2 模型的保存與復用

    • 2.1 查看網絡模型參數

    • 2.2 載入模型進行推斷

    • 2.3 載入模型進行訓練

    • 2.4 載入模型進行遷移

  • 3 總結

    1 引言

    各位朋友大家好,歡迎來到月來客棧。今天要和大家介紹的內容是如何在Pytorch框架中對模型進行保存和載入、以及模型的遷移和再訓練。一般來說,最常見的場景就是模型完成訓練后的推斷過程。一個網絡模型在完成訓練后通常都需要對新樣本進行預測,此時就只需要構建模型的前向傳播過程,然后載入已訓練好的參數初始化網絡即可。

    第2個場景就是模型的再訓練過程。一個模型在一批數據上訓練完成之后需要將其保存到本地,并且可能過了一段時間后又收集到了一批新的數據,因此這個時候就需要將之前的模型載入進行在新數據上進行增量訓練(或者是在整個數據上進行全量訓練)。

    第3個應用場景就是模型的遷移學習。這個時候就是將別人已經訓練好的預模型拿過來,作為你自己網絡模型參數的一部分進行初始化。例如:你自己在Bert模型的基礎上加了幾個全連接層來做分類任務,那么你就需要將原始BERT模型中的參數載入并以此來初始化你的網絡中的BERT部分的權重參數。

    在接下來的這篇文章中,筆者就以上述3個場景為例來介紹如何利用Pytorch框架來完成上述過程。

    2 模型的保存與復用

    在Pytorch中,我們可以通過torch.save()torch.load()來完成上述場景中的主要步驟。下面,筆者將以之前介紹的LeNet5網絡模型為例來分別進行介紹。不過在這之前,我們先來看看Pytorch中模型參數的保存形式。

    2.1 查看網絡模型參數

    (1)查看參數

    首先定義好LeNet5的網絡模型結構,如下代碼所示:

    class LeNet5(nn.Module):
        def __init__(self, ):
            super(LeNet5, self).__init__()
            self.conv = nn.Sequential(  # [n,1,28,28]
                nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
                nn.ReLU(),  # [n,6,24,24]
                nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
                nn.Conv2d(6, 16, 5),  # [n,16,10,10]
                nn.ReLU(),
                nn.MaxPool2d(2, 2))  # [n,16,5,5]
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, 10))
        def forward(self, img):
            output = self.conv(img)
            output = self.fc(output)
            return output

    在定義好LeNet5這個網絡結構的類之后,只要我們完成了這個類的實例化操作,那么網絡中對應的權重參數也都完成了初始化的工作,即有了一個初始值。同時,我們可以通過如下方式來訪問:

    # Print model's state_dict
    print("Model's state_dict:")
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())

    其輸出的結果為:

    conv.0.weight torch.Size([6, 1, 5, 5])
    conv.0.bias torch.Size([6])
    conv.3.weight torch.Size([16, 6, 5, 5])
    ....
    ....

    可以發(fā)現,網絡模型中的參數model.state_dict()其實是以字典的形式(實質上是collections模塊中的OrderedDict)保存下來的:

    print(model.state_dict().keys())
    # odict_keys(['conv.0.weight', 'conv.0.bias', 'conv.3.weight', 'conv.3.bias', 'fc.1.weight', 'fc.1.bias', 'fc.3.weight', 'fc.3.bias', 'fc.5.weight', 'fc.5.bias'])

    (2)自定義參數前綴

    同時,這里值得注意的地方有兩點:①參數名中的fcconv前綴是根據你在上面定義nn.Sequential()時的名字所確定的;②參數名中的數字表示每個Sequential()中網絡層所在的位置。例如將網絡結構定義成如下形式:

    class LeNet5(nn.Module):
        def __init__(self, ):
            super(LeNet5, self).__init__()
            self.moon = nn.Sequential(  # [n,1,28,28]
                nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
                nn.ReLU(),  # [n,6,24,24]
                nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
                nn.Conv2d(6, 16, 5),  # [n,16,10,10]
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Flatten(),
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, 10))

    那么其參數名則為:

    print(model.state_dict().keys())
    odict_keys(['moon.0.weight', 'moon.0.bias', 'moon.3.weight', 'moon.3.bias', 'moon.7.weight', 'moon.7.bias', 'moon.9.weight', 'moon.9.bias', 'moon.11.weight', 'moon.11.bias'])

    理解了這一點對于后續(xù)我們去解析和載入一些預訓練模型很有幫助。

    除此之外,對于中的優(yōu)化器等,其同樣有對應的state_dict()方法來獲取對于的參數,例如:

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    print("Optimizer's state_dict:")
    for var_name in optimizer.state_dict():
       print(var_name, "\t", optimizer.state_dict()[var_name])
        
    #
    Optimizer's state_dict:
    state   {}
    param_groups   [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140239245300504, 140239208339784, 140239245311360, 140239245310856, 140239266942480, 140239266942552, 140239266942624, 140239266942696, 140239266942912, 140239267041352]}]

    在介紹完模型參數的查看方法后,就可以進入到模型復用階段的內容介紹了。

    2.2 載入模型進行推斷

    (1) 模型保存

    在Pytorch中,對于模型的保存來說是非常簡單的,通常來說通過如下兩行代碼便可以實現:

    model_save_path = os.path.join(model_save_dir, 'model.pt')
    torch.save(model.state_dict(), model_save_path)

    在指定保存的模型名稱時Pytorch官方建議的后綴為.pt或者.pth(當然也不是強制的)。最后,只需要在合適的地方加入第2行代碼即可完成模型的保存。

    同時,如果想要在訓練過程中保存某個條件下的最優(yōu)模型,那么應該通過如下方式:

    best_model_state = deepcopy(model.state_dict()) 
    torch.save(best_model_state, model_save_path)

    而不是:

    best_model_state = model.state_dict() 
    torch.save(best_model_state, model_save_path)

    因為后者best_model_state得到只是model.state_dict()的引用,它依舊會隨著訓練過程而發(fā)生改變。

    (2)復用模型進行推斷

    在推斷過程中,首先需要完成網絡的初始化,然后再載入已有的模型參數來覆蓋網絡中的權重參數即可,示例代碼如下:

    def inference(data_iter, device, model_save_dir='./MODEL'):
        model = LeNet5()  # 初始化現有模型的權重參數
        model.to(device)
        model_save_path = os.path.join(model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path)
            model.load_state_dict(loaded_paras)  # 用本地已有模型來重新初始化網絡權重參數
            model.eval() # 注意不要忘記
        with torch.no_grad():
            acc_sum, n = 0.0, 0
            for x, y in data_iter:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                acc_sum += (logits.argmax(1) == y).float().sum().item()
                n += len(y)
            print("Accuracy in test data is :", acc_sum / n)

    在上述代碼中,4-7行便是用來載入本地模型參數,并用其覆蓋網絡模型中原有的參數。這樣,便可以進行后續(xù)的推斷工作:

    Accuracy in test data is : 0.8851

    2.3 載入模型進行訓練

    在介紹完模型的保存與復用之后,對于網絡的追加訓練就很簡單了。最簡便的一種方式就是在訓練過程中只保存網絡權重,然后在后續(xù)進行追加訓練時只載入網絡權重參數初始化網絡進行訓練即可,示例如下(完整代碼參見[2]):

     def train(self):
            #......
            model_save_path = os.path.join(self.model_save_dir, 'model.pt')
            if os.path.exists(model_save_path):
                loaded_paras = torch.load(model_save_path)
                self.model.load_state_dict(loaded_paras)
                print("#### 成功載入已有模型,進行追加訓練...")
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)  # 定義優(yōu)化器
           #......
            for epoch in range(self.epochs):
                for i, (x, y) in enumerate(train_iter):
                    x, y = x.to(device), y.to(device)
                    logits = self.model(x)
                    # ......
                print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
                                                  self.evaluate(test_iter, self.model, device)))
                torch.save(self.model.state_dict(), model_save_path)

    這樣,便完成了模型的追加訓練:

    #### 成功載入已有模型,進行追加訓練...
    Epochs[0/5]---batch[938/0]---acc 0.9062---loss 0.2926
    Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.1598
    ......

    除此之外,你也可以在保存參數的時候,將優(yōu)化器參數、損失值等一同保存下來,然后在恢復模型的時候連同其它參數一起恢復,示例如下:

    model_save_path = os.path.join(model_save_dir, 'model.pt')
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
                }, model_save_path)

    載入方式如下:

    checkpoint = torch.load(model_save_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    2.4 載入模型進行遷移

    (1)定義新模型

    到目前為止,對于前面兩種應用場景的介紹就算完成了,可以發(fā)現總體上并不復雜。但是對于第3中場景的應用來說就會略微復雜一點。

    假設現在有一個LeNet6網絡模型,它是在LeNet5的基礎最后多加了一個全連接層,其定義如下:

    class LeNet6(nn.Module):
        def __init__(self, ):
            super(LeNet6, self).__init__()
            self.conv = nn.Sequential(  # [n,1,28,28]
                nn.Conv2d(1, 6, 5, padding=2),  # in_channels, out_channels, kernel_size
                nn.ReLU(),  # [n,6,24,24]
                nn.MaxPool2d(2, 2),  # kernel_size, stride  [n,6,14,14]
                nn.Conv2d(6, 16, 5),  # [n,16,10,10]
                nn.ReLU(),
                nn.MaxPool2d(2, 2))  # [n,16,5,5]
            self.fc = nn.Sequential(
                nn.Flatten(),
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, 64), 
                nn.ReLU(),
                nn.Linear(64, 10) ) # 新加入的全連接層

    接下來,我們需要將在LeNet5上訓練得到的權重參數遷移到LeNet6網絡中去。從上面LeNet6的定義可以發(fā)現,此時盡管只是多加了一個全連接層,但是倒數第2層參數的維度也發(fā)生了變換。因此,對于LeNet6來說只能復用LeNet5網絡前面4層的權重參數。

    (2)查看模型參數

    在拿到一個模型參數后,首先我們可以將其載入,然查看相關參數的信息:

    model_save_path = os.path.join('./MODEL', 'model.pt')
    loaded_paras = torch.load(model_save_path)
    for param_tensor in loaded_paras:
        print(param_tensor, "\t", loaded_paras[param_tensor].size())
    
    #---- 可復用部分
    conv.0.weight   torch.Size([6, 1, 5, 5])
    conv.0.bias   torch.Size([6])
    conv.3.weight   torch.Size([16, 6, 5, 5])
    conv.3.bias   torch.Size([16])
    fc.1.weight   torch.Size([120, 400])
    fc.1.bias   torch.Size([120])
    fc.3.weight   torch.Size([84, 120])
    fc.3.bias   torch.Size([84])
    #----- 不可復用部分
    fc.5.weight   torch.Size([10, 84])
    fc.5.bias   torch.Size([10])

    同時,對于LeNet6網絡的參數信息為:

    model = LeNet6()
    for param_tensor in model.state_dict():
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())
    #
    conv.0.weight   torch.Size([6, 1, 5, 5])
    conv.0.bias   torch.Size([6])
    conv.3.weight   torch.Size([16, 6, 5, 5])
    conv.3.bias   torch.Size([16])
    fc.1.weight   torch.Size([120, 400])
    fc.1.bias   torch.Size([120])
    fc.3.weight   torch.Size([84, 120])
    fc.3.bias   torch.Size([84])
    #------ 新加入部分
    fc.5.weight   torch.Size([64, 84])
    fc.5.bias   torch.Size([64])
    fc.7.weight   torch.Size([10, 64])
    fc.7.bias   torch.Size([10])

    在理清楚了新舊模型的參數后,下面就可以將LeNet5中我們需要的參數給取出來,然后再換到LeNet6的網絡中。

    (3)模型遷移

    雖然本地載入的模型參數(上面的loaded_paras)和模型初始化后的參數(上面的model.state_dict())都是一個字典的形式,但是我們并不能夠直接改變model.state_dict()中的權重參數。這里需要先構造一個state_dict然后通過model.load_state_dict()方法來重新初始化網絡中的參數。

    同時,在這個過程中我們需要篩選掉本地模型中不可復用的部分,具體代碼如下:

    def para_state_dict(model, model_save_dir):
        state_dict = deepcopy(model.state_dict())
        model_save_path = os.path.join(model_save_dir, 'model.pt')
        if os.path.exists(model_save_path):
            loaded_paras = torch.load(model_save_path)
            for key in state_dict:  # 在新的網絡模型中遍歷對應參數
                if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
                    print("成功初始化參數:", key)
                    state_dict[key] = loaded_paras[key]
        return state_dict

    在上述代碼中,第2行的作用是先拷貝網絡中(LeNet6)原有的參數;第6-9行則是用本地的模型參數(LeNet5)中可以復用的替換掉LeNet6中的對應部分,其中第7行就是判斷可用的條件。同時需要注意的是在不同的情況下篩選的方式可能不一樣,因此具體情況需要具體分析,但是整體邏輯是一樣的。

    最后,我們只需要在模型訓練之前調用該函數,然后重新初始化LeNet6中的部分權重參數即可[2]:

    state_dict = para_state_dict(self.model, self.model_save_dir)
    self.model.load_state_dict(state_dict)

    訓練結果如下:

    成功初始化參數: conv.0.weight
    成功初始化參數: conv.0.bias
    成功初始化參數: conv.3.weight
    成功初始化參數: conv.3.bias
    成功初始化參數: fc.1.weight
    成功初始化參數: fc.1.bias
    成功初始化參數: fc.3.weight
    成功初始化參數: fc.3.bias
    #### 成功載入已有模型,進行追加訓練...
    Epochs[0/5]---batch[938/0]---acc 0.1094---loss 2.512
    Epochs[0/5]---batch[938/100]---acc 0.9375---loss 0.2141
    Epochs[0/5]---batch[938/200]---acc 0.9219---loss 0.2729
    Epochs[0/5]---batch[938/300]---acc 0.8906---loss 0.2958
    ......
    Epochs[0/5]---batch[938/900]---acc 0.8906---loss 0.2828
    Epochs[0/5]--acc on test 0.8808

    可以發(fā)現,在大約100個batch之后,模型的準確率就提升上來了。

    “關于Pytorch中模型的保存與遷移問題有哪些”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業(yè)相關的知識可以關注億速云網站,小編將為大家輸出更多高質量的實用文章!

    向AI問一下細節(jié)

    免責聲明:本站發(fā)布的內容(圖片、視頻和文字)以原創(chuàng)、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

    AI