PyTorch模型的保存与加载
在PyTorch中,保存和加载模型是深度学习工作流程中的关键步骤。通常,保存和加载模型有两种主要方法:保存和加载整个模型或只保存和加载模型的状态字典(state_dict)。以下是这两种方法的详细说明和示例代码:
方法一:保存和加载整个模型
保存模型
你可以使用 torch.save()
函数来保存整个模型,包括模型的结构和参数。
import torch
# 假设 `model` 是你的神经网络模型
torch.save(model, 'model.pth')
加载模型
要加载整个模型,你可以使用 torch.load()
函数,然后调用 eval()
以将模型设为评估模式。
import torch
# 从文件加载模型
model = torch.load('model.pth')
model.eval()
方法二:保存和加载模型的状态字典
这种方法更推荐使用,尤其是当你需要调整或定义与保存模型不同的网络结构时。
保存模型的状态字典
你可以通过调用模型的 state_dict()
方法来获取并保存它的状态字典。
# 保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')
加载模型的状态字典
加载模型的状态字典时,需要先实例化模型对象,然后使用 load_state_dict()
方法加载参数。
import torch
import torch.nn as nn
# 重新创建模型的类实例
model = MyModelClass() # 假设 MyModelClass 是你定义的模型类
# 加载状态字典
model.load_state_dict(torch.load('model_state_dict.pth'))
# 切换到评估模式
model.eval()
使用提示
评估模式:加载模型后调用
model.eval()
将模型设为评估模式,这会关闭诸如 dropout 和 batch normalization 的训练时功能。兼容性:保存整个模型时,确保在加载时的运行环境和 PyTorch 版本与保存时一致,以避免由于版本差异导致的加载错误。而使用状态字典方法时,只需保证模型结构一致即可。
GPU/CPU切换:如果在GPU上保存模型,并想在CPU上加载,或者反过来,则需要在加载时指定
map_location
参数。例如:- 在GPU上保存,在CPU上加载:
model.load_state_dict(torch.load('model_state_dict.pth', map_location=torch.device('cpu')))
保存优化器状态:如果你想继续训练模型,还可以保存和加载优化器的状态字典。
# 保存优化器状态
torch.save(optimizer.state_dict(), 'optimizer_state.pth')
# 加载优化器状态
optimizer.load_state_dict(torch.load('optimizer_state.pth'))
以上是在PyTorch中保存和加载模型的基本方法,选择适合你的工作流程的方式是关键。