提示信息

PyTorch模型的保存与加载

水晶叶子 已关注

在PyTorch中,保存和加载模型是深度学习工作流程中的关键步骤。通常,保存和加载模型有两种主要方法:保存和加载整个模型或只保存和加载模型的状态字典(state_dict)。以下是这两种方法的详细说明和示例代码:

方法一:保存和加载整个模型

保存模型

你可以使用 torch.save() 函数来保存整个模型,包括模型的结构和参数。

  1. import torch
  2. # 假设 `model` 是你的神经网络模型
  3. torch.save(model, 'model.pth')

加载模型

要加载整个模型,你可以使用 torch.load() 函数,然后调用 eval() 以将模型设为评估模式。

  1. import torch
  2. # 从文件加载模型
  3. model = torch.load('model.pth')
  4. model.eval()

方法二:保存和加载模型的状态字典

这种方法更推荐使用,尤其是当你需要调整或定义与保存模型不同的网络结构时。

保存模型的状态字典

你可以通过调用模型的 state_dict() 方法来获取并保存它的状态字典。

  1. # 保存模型的状态字典
  2. torch.save(model.state_dict(), 'model_state_dict.pth')

加载模型的状态字典

加载模型的状态字典时,需要先实例化模型对象,然后使用 load_state_dict() 方法加载参数。

  1. import torch
  2. import torch.nn as nn
  3. # 重新创建模型的类实例
  4. model = MyModelClass() # 假设 MyModelClass 是你定义的模型类
  5. # 加载状态字典
  6. model.load_state_dict(torch.load('model_state_dict.pth'))
  7. # 切换到评估模式
  8. model.eval()

使用提示

  1. 评估模式:加载模型后调用 model.eval() 将模型设为评估模式,这会关闭诸如 dropout 和 batch normalization 的训练时功能。

  2. 兼容性:保存整个模型时,确保在加载时的运行环境和 PyTorch 版本与保存时一致,以避免由于版本差异导致的加载错误。而使用状态字典方法时,只需保证模型结构一致即可。

  3. GPU/CPU切换:如果在GPU上保存模型,并想在CPU上加载,或者反过来,则需要在加载时指定 map_location 参数。例如:

    • 在GPU上保存,在CPU上加载:
    1. model.load_state_dict(torch.load('model_state_dict.pth', map_location=torch.device('cpu')))
  4. 保存优化器状态:如果你想继续训练模型,还可以保存和加载优化器的状态字典。

    1. # 保存优化器状态
    2. torch.save(optimizer.state_dict(), 'optimizer_state.pth')
    3. # 加载优化器状态
    4. optimizer.load_state_dict(torch.load('optimizer_state.pth'))

以上是在PyTorch中保存和加载模型的基本方法,选择适合你的工作流程的方式是关键。

    遇到难题? "AI大模型GPT4.0、GPT" 是你的私人解答专家! 点击按钮去提问......
水晶叶子 关注 已关注

最近一次登录:2024-11-20 11:52:49   

暂时还没有签名,请关注我或评论我的文章
×
免费图表工具,画流程图、架构图