当前位置:网站首页 > PyTorch框架 > 正文

pytorch模型部署方案(pytorch模型的保存与加载)



在深度学习的训练过程中,我们不可避免地要保存模型,这是一个非常好的习惯。接下来,文章将通过一个简单的神经网络模型,带你了解 PyTorch 中主要的模型保存与加载方式。

训练一个神经网络可能需要数小时甚至数天的时间,你需要认知到一点:时间是非常宝贵的,目前3090云服务器租赁一天的价格为 37.92 元。如果你的代码没有保存模型的模块,那就先不要开始,因为不保存基本等于没跑,你的效果再好也没有办法直接呈现给别人。如果你保存了模型,你就可以做到以下的事情:

  • 继续训练:通过保存检查点(checkpoint),你可以在意外中断后继续训练你的模型,这一点可能会节省你大量的时间。
  • 模型部署:训练好的模型可以被部署到生产环境中进行推理,比如 LLM,LoRA 等。
  • 分享模型:将训练好的模型分享给实验室其他成员或开源社区,以便进一步研究或复现结果。

为了演示,我们先定义一个简单的神经网络模型:

 

保存模型

 

加载模型

 

输出

 

这种方法非常简单直观,因为它保存了模型的整个结构和参数。

保存模型状态字典

 

加载模型状态字典
需要注意的是,加载state_dict时你需要手动重新实例化模型。

 

输出

 

与保存整个模型相比,保存 更加灵活,它只包含模型的参数,而不依赖于完整的模型定义,这意味着你可以在不同的项目中加载模型参数,甚至只加载部分模型的权重。举个例子,对于分类模型,即便你保存的是完整的网络参数,也可以仅导入特征提取层部分,当然,直接导入完整模型再拆分实际上是一样的。对于不完全匹配的模型,加载时可以通过设置 来忽略某些不匹配的键:

 

这样,你可以灵活地只加载模型的某些部分。

使用 加载模型

假设我们在原来的 模型中新增了一个全连接层(),此时如果我们直接加载之前保存的 ,会因为 中没有 的权重信息而导致报错。

 

输出

 

如果不设置 ,将会报错,提示缺少 的权重:

 

注意,减少层也可以使用 。例如,如果修改后的网络只保留前两层,仍然可以成功加载原始的 ,并跳过缺失的部分。

有时候,你可能不仅仅需要保存模型参数,还需要保存训练进度,比如当前的轮数、优化器状态等。此时可以使用检查点保存更多信息。

保存检查点

 

加载检查点

 

输出:

 

这种方式适合长时间训练时,可以从中断的地方继续训练。但文件体积相比前面会更大,具体原因见《7. 探究模型参数与显存的关系以及不同精度造成的影响》,加载过程也稍微复杂一些,我们可以写一个函数来打包这个过程。

到此这篇pytorch模型部署方案(pytorch模型的保存与加载)的文章就介绍到这了,更多相关内容请继续浏览下面的相关推荐文章,希望大家都能在编程的领域有一番成就!
                            

版权声明


相关文章:

  • pytorch模型部署(pytorch模型部署到web)2024-12-11 11:45:08
  • pytorch模型部署到web(pytorch模型部署到树莓派)2024-12-11 11:45:08
  • pytorch模型部署到web(pytorch模型部署到Linux)2024-12-11 11:45:08
  • 尽情享受生活之乐趣——蒙田2024-12-11 11:45:08
  • 一)pytorch框架与环境搭建_一)pytorch框架与环境搭建2024-12-11 11:45:08
  • pytorch深度学习框架基本介绍_pytorch深度学习框架基本介绍2024-12-11 11:45:08
  • pytorch框架搭建_pytorch框架搭建2024-12-11 11:45:08
  • 人工智能入门 | PyTorch框架介绍和安装步骤_人工智能入门 | PyTorch框架介绍和安装步骤2024-12-11 11:45:08
  • pytorch模型部署onnx(pytorch模型部署 django)2024-12-11 11:45:08
  • 服务器配置pytorch环境(pytorch如何在服务器上跑)2024-12-11 11:45:08
  • 全屏图片