模型持久化
(模型保存与加载)是机器学习完成的最后一步。
因为,在实际情况中,训练一个模型可能会非常耗时,如果每次需要使用模型时都要重新训练,这无疑会浪费大量的计算资源和时间。

通过将训练好的模型持久化到磁盘,我们可以在需要使用模型时直接从磁盘加载到内存,而无需重新训练。这样不仅可以节省时间,还可以提高模型的使用效率。

本篇介绍
scikit-learn
中几种常用的模型持久化方法。

1. 训练模型

首先,训练一个模型,这里用
scikit-learn
自带的
手写数字数据集
作为样本。

import matplotlib.pyplot as plt
from sklearn import datasets

# 加载手写数据集
data = datasets.load_digits()

# 调整数据格式
n_samples = len(data.images)
X = data.images.reshape((n_samples, -1))
y = data.target

# 用支持向量机训练模型
from sklearn.svm import SVC

# 定义
reg = SVC()

# 训练模型
reg.fit(X, y)

最后的得到的
reg
就是我们训练之后的模型,使用这个模型,就可以预测一些手写数字图片。

但是这个
reg
是代码中的一个变量,如果不能保存下来,那么,每次需要使用的时候,
还要重新执行一次上面的模型训练代码,样本数据量大的话,每次重复训练会浪费大量时间和计算资源。

所以,要将上面的
reg
模型保存下来,下次使用的时候,直接加载,不用重新训练。

2. 模型持久化

2.1. pickle 序列化

pickle
格式是
python
中常用的序列化方式,它通过将
python对象
及其所拥有的层次结构转化为一个字节流来实现序列化。

将上面的模型保存到磁盘文件
model.pkl
中。

import pickle

with open("./model.pkl", "wb") as f:
    pickle.dump(reg, f)

需要使用模型时,从磁盘加载的方式:

with open("./model.pkl", "rb") as f:
    reg_pkl = pickle.load(f)

验证加载之后的模型
reg_pkl
是否可以正常使用。

y_pred = reg_pkl.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()

plt.show()

image.png
从混淆矩阵来看,模型可以正常加载和使用。
关于混淆矩阵具体内容,可以参考:
【scikit-learn基础】--『分类模型评估』之评估报告

2.2. joblib 序列化

相比于
pickle
,保存机器学习模型时,更推荐使用
joblib

因为
joblib
针对大数据进行了优化,使其在处理大型数据集时性能更佳。

序列化的方式也很简单:

import joblib

joblib.dump(reg, "model.jlib")

从磁盘加载模型并验证:

reg_jlib = joblib.load("model.jlib")

y_pred = reg_jlib.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()

plt.show()

image.png

2.3. skops 格式

skops是比较新的一种格式,它是专门为了共享基于
scikit-learn
的模型而开发的。
目前还在积极的开发中,github上的地址是:
github-skops

相比于
pickle

joblib
,它提供了更加安全的序列化格式,
但使用上和它们差别不大。

import skops.io as sio

# 保存到文件 model.sio
sio.dump(reg, "model.sio")

从文件中读取模型并验证:

reg_sio = sio.load("model.sio")

y_pred = reg_jlib.predict(X)

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()

plt.show()

image.png

3. 总结


scikit-learn
中,模型持久化是一个重要且实用的技术,它允许我们将训练好的模型保存到磁盘上,以便在不同的时间点或不同的环境中重新加载和使用。
通过模型持久化,我们能够避免每次需要使用时重新训练模型,从而节省大量的时间和计算资源。

本篇介绍的三种方法可以方便的序列化和反序列化模型对象,使其可以轻松地保存到磁盘上,并能够在需要时恢复出原始模型对象。

总而言之,模型持久化不仅使得我们能够在不同的运行会话之间重用模型,还方便了模型的共享和部署。

标签: none

添加新评论