PyTorch教程-7:PyTorch中保存与加载tensor和模型详解

笔者PyTorch的全部简单教程请访问:https://www.jianshu.com/nb/48831659

PyTorch教程-7:PyTorch中保存与加载tensor和模型详解

保存和读取Tensor

PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save()方法保存张量,使用torch.load()来读取张量:

x = torch.rand(4,5)
torch.save(x, "./myTensor.pt")

y = torch.load("./myTensor.pt")
print(y)

tensor([[0.9363, 0.2292, 0.1612, 0.9558, 0.9414],
        [0.3649, 0.9622, 0.3547, 0.5772, 0.7575],
        [0.7005, 0.8115, 0.6132, 0.6640, 0.1173],
        [0.6999, 0.1023, 0.8544, 0.7708, 0.1254]])

当然,saveload方法也适用于其他数据类型,比如list、tuple、dict等:

a = {'a':torch.rand(2,2), 'b':torch.rand(3,4)}
torch.save(a, "./myDict.pth")

b = torch.load("./myDict.pth")
print(b)

{'a': tensor([[0.9356, 0.0240],
        [0.6004, 0.3923]]), 'b': tensor([[0.0222, 0.1799, 0.9172, 0.8159],
        [0.3749, 0.6689, 0.4796, 0.5772],
        [0.5016, 0.5279, 0.5109, 0.0592]])}

保存Tensor的纯数据

PyTorch中,使用 torch.save 保存的不仅有其中的数据,还包括一些它的信息,包括它与其它数据(可能存在)的关系,这一点是很有趣的。

This is an implementation detail that may change in the future, but it typically saves space and lets PyTorch easily reconstruct the view relationships between the loaded tensors.

详细的原文可以参考:https://pytorch.org/docs/stable/notes/serialization.html#saving-and-loading-tensors-preserves-views

这里结合例子给出一个简单的解释。

x = torch.arange(20)
y = x[:5]

torch.save([x,y], "./myTensor.pth")
x_, y_ = torch.load("././myTensor.pth")

y_ *= 100

print(x_)

tensor([  0, 100, 200, 300, 400,   5,   6,   7,   8,   9,  10,  11,  12,  13, 14,  15,  16,  17,  18,  19])

比如在上边的例子中,我们看到yx的一个前五位的切片,当我们同时保存xy后,它们的切片关系也被保存了下来,再将他们加载出来,它们之间依然保留着这个关系,因此可以看到,我们将加载出来的 y_ 乘以100后,x_ 也跟着变化了。

如果不想保留他们的关系,其实也很简单,再保存y之前使用 clone 方法保存一个只有数据的“克隆体”,这样就能只保存数据而不保留关系:

x = torch.arange(20)
y = x[:5]

torch.save([x,y.clone()], "./myTensor.pth")
x_, y_ = torch.load("././myTensor.pth")

y_ *= 100

print(x_)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

当我们只保存y而不同时保存x会怎样呢?这样的话确实可以避免如上的情况,即不会再在读取数据后保留他们的关系,但是实际上有一个不容易被看到的影响存在,那就是保存的数据所占用的空间会和其“父亲”级别的数据一样大

x = torch.arange(1000)
y = x[:5]

torch.save(y, "./myTensor1.pth")
torch.save(y.clone(), "./myTensor2.pth")

y1_ = torch.load("./myTensor1.pth")
y2_ = torch.load("./myTensor2.pth")

print(y1_.storage().size())
print(y2_.storage().size())

1000
5

如果你去观察他们保存的文件,会发现占用的空间确实存在很大的差距:

myTensor1.pth      9KB
myTensor2.pth      1KB

综上所述,对于一些“被关系”的数据来说,如果不想保留他们的关系,最好使用 clone 来保存其“纯数据”

保存与加载模型

保存与加载state_dict

这是一种较为推荐的保存方法,即只保存模型的参数,保存的模型文件会较小,而且比较灵活。但是当加载时,需要先实例化一个模型,然后通过加载将参数赋给这个模型的实例,也就是说加载之前使用者需要知道模型的结构

  • 保存:
    torch.save(model.state_dict(), PATH)
    
  • 加载:
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
    

比较重要的点是:

  • 保存模型时调用 state_dict() 获取模型的参数,而不保存结构
  • 加载模型时需要预先实例化一个对应的网络,比如net=MyNet(),这也就意味着,使用者需要预先有MyNet这个类,如果他/她不知道这个网络的类定义或者结构,这种只保存参数的方法将无法使用
  • 加载模型使用 load_state_dict 方法,其参数不是文件路径,而是 torch.load(PATH)
  • 如果加载出来的模型用于验证,不要忘了使用 model.eval() 方法,它会丢弃 dropout、normalization 等层,因为这些层不能在inference的时候使用,否则得到的推断结果不一致。

一个例子:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        # convolution layers
        self.conv1 = nn.Conv2d(1,6,3)
        self.conv2 = nn.Conv2d(6,16,3)

        # fully-connection layers
        self.fc1 = nn.Linear(16*6*6,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        # max pooling over convolution layers
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)

        # fully-connected layers followed by activation functions
        x = x.view(-1,16*6*6)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # final fully-connected without activation functon
        x = self.fc3(x)

        return x

net = Net()

torch.save(net.state_dict(), "./myModel.pth")

loaded_net = Net()
loaded_net.load_state_dict(torch.load("./myModel.pth"))
loaded_net.eval()

Net(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

保存与加载整个模型

这种方式不仅保存、加载模型的数据,也包括模型的结构一并存储,存储的文件会较大,好处是加载时不需要提前知道模型的结构,解来即用。实际上这与上文提到的保存Tensor是一致的。

  • 保存:
    torch.save(model, PATH)
    
  • 加载:
    model = torch.load(PATH)
    model.eval()
    

同样的,如果加载的模型用于inference,则需要使用 model.eval()

保存与加载模型与其他信息

有时我们不仅要保存模型,还要连带保存一些其他的信息。比如在训练过程中保存一些 checkpoint,往往除了模型,还要保存它的epoch、loss、optimizer等信息,以便于加载后对这些 checkpoint 继续训练等操作;或者再比如,有时候需要将多个模型一起打包保存等。这些其实也很简单,正如我们上文提到的,torch.save 可以保存dict、list、tuple等多种数据结构,所以一个字典可以很完美的解决这个问题,比如一个简单的例子:

# saving
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

# loading
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

跨设备存储与加载

跨设备的情况指对于一些数据的保存、加载在不同的设备上,比如一个在CPU上,一个在GPU上的情况,大致可以分为如下几种情况:

从CPU保存,加载到CPU

实际上,这就是默认的情况,我们上文提到的所有内容都没有关心设备的问题,因此也就适应于这种情况。

从CPU保存,加载到GPU

  • 保存:依旧使用默认的方法
  • 加载:有两种可选的方式
    • 使用 torch.load() 函数的 map_location 参数指定加载后的数据保存的设备
    • 对于加载后的模型使用 to() 函数发送到设备
torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))
# or
loaded_net.to(device)

从GPU保存,加载到CPU

  • 保存:依旧使用默认的方法
  • 加载:只能使用 torch.load() 函数的 map_location 参数指定加载后的数据保存的设备
torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))

从GPU保存,加载到GPU

  • 保存:依旧使用默认的方法
  • 加载:只能使用 对于加载后的模型进行 to() 函数发送到设备
torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.to(device)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 161,192评论 4 369
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 68,186评论 1 303
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 110,844评论 0 252
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,471评论 0 217
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,876评论 3 294
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,891评论 1 224
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 32,068评论 2 317
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,791评论 0 205
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,539评论 1 249
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,772评论 2 253
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,250评论 1 265
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,577评论 3 260
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,244评论 3 241
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,146评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,949评论 0 201
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,995评论 2 285
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,812评论 2 276

推荐阅读更多精彩内容