求知 文章 文库 Lib 视频 iPerson 课程 认证 咨询 工具 讲座 Modeler   Code  
会员   
 


业务架构设计
4月18-19日 在线直播



基于UML和EA进行系统分析设计
4月25-26日 北京+在线



AI 智能化软件测试方法与实践
5月23-24日 上海+在线
 
追随技术信仰

随时听讲座
每天看新闻
 
 
Sklearn 教程
1. Sklearn简介
2. Sklearn安装
3. Sklearn 基础概念
4. Sklearn 数据预处理
5. Sklearn 机器学习模型
6. Sklearn 模型评估与调优
7. Sklearn 管道(Pipeline)
8. Sklearn 自定义模型与功能
9. Sklearn 模型保存与加载
9. Sklearn 应用案例
 

 
目录
Sklearn 模型保存与加载
57 次浏览
1次  

在机器学习中,模型的训练过程通常是耗时的,为了避免每次重新训练模型,我们可以将训练好的模型保存下来,便于以后进行加载和预测。

scikit-learn 提供了两种常用的方式来保存和加载模型:joblib 和 pickle。

1、使用 joblib 保存与加载模型

joblib 是一个高效的 Python 序列化工具,特别适合用于保存包含大量数值数组(如 numpy 数组、scikit-learn 模型等)的对象。相较于 pickle,joblib 在处理大规模数据时更高效。

joblib 是 Python 的一个外部库,可以通过以下命令安装:

pip install joblib

保存模型

joblib 提供了一个简单的 API 来保存和加载对象。

我们可以使用 joblib.dump() 方法将模型保存到文件中。

实例

import joblib
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# 加载数据
data = load_iris()
X, y = data.data, data.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建并训练模型
model = SVC(kernel='linear')
model.fit(X_train, y_train)

# 保存模型到文件
joblib.dump(model, 'svm_model.joblib')

加载模型

使用 joblib.load() 方法加载保存的模型对象。

实例

# 加载保存的模型
loaded_model = joblib.load('svm_model.joblib')

# 使用加载的模型进行预测
y_pred = loaded_model.predict(X_test)

# 打印预测结果
print("Predictions:", y_pred)

通过上述步骤,我们成功地将训练好的模型保存到文件中,并在之后的任何时间加载该模型并进行预测。

2、使用 pickle 保存与加载模型

pickle 是 Python 内置的模块,允许将 Python 对象序列化和反序列化。

虽然 joblib 更适用于处理大量数据,但 pickle 也是常用的保存和加载模型的工具,适用于一般情况。

保存模型

与 joblib 类似,pickle 也有简单的 API 来保存和加载对象。

保存模型的代码如下:

实例

import pickle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# 加载数据
data = load_iris()
X, y = data.data, data.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建并训练模型
model = SVC(kernel='linear')
model.fit(X_train, y_train)

# 使用 pickle 保存模型
with open('svm_model.pkl', 'wb') as f:
    pickle.dump(model, f)

加载模型

使用 pickle.load() 加载模型:

实例

# 使用 pickle 加载保存的模型
with open('svm_model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

# 使用加载的模型进行预测
y_pred = loaded_model.predict(X_test)

# 打印预测结果
print("Predictions:", y_pred)

3、joblib vs pickle

joblib 和 pickle 是保存和加载模型的两种常用方法。

joblib 更适合保存大型数据对象,而 pickle 是 Python 的标准序列化工具,适用于一般情况。

  • joblib:通常适用于保存包含大量数值数据(如 numpy 数组)的对象。joblib 在处理大规模数据时比 pickle 更高效。

  • pickle:适用于保存较小的对象或常规的 Python 对象。它是 Python 的内置库,使用时无需额外安装。

如果模型中包含大量数值数组或矩阵(如支持向量机、随机森林等),推荐使用 joblib,它比 pickle 更高效。对于较小的模型或不包含大量数值数据的模型,pickle 足够使用。

4、保存和加载管道(Pipeline)

在实际应用中,模型不仅仅是单一的模型,有时会结合多个处理步骤(如数据预处理、特征选择、模型训练等),这些处理步骤可以使用 scikit-learn 的 Pipeline 来完成。Pipeline 也可以通过 joblib 或 pickle 保存和加载。

保存管道:

实例

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import joblib

# 创建一个管道
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('svc', SVC(kernel='linear'))
])

# 训练管道
pipeline.fit(X_train, y_train)

# 保存管道到文件
joblib.dump(pipeline, 'pipeline_model.joblib')

加载管道:

实例

# 加载管道
loaded_pipeline = joblib.load('pipeline_model.joblib')

# 使用加载的管道进行预测
y_pred = loaded_pipeline.predict(X_test)

# 打印预测结果
print("Predictions:", y_pred)

管道保存和加载的过程与单一模型相同,只需确保保存和加载整个管道对象即可。

5、模型版本管理

在机器学习的实际应用中,模型的更新和版本管理至关重要。每次训练模型并保存时,最好为模型文件命名加上时间戳或版本号,以便区分不同版本的模型。例如:

实例

import time

# 创建时间戳
timestamp = time.strftime("%Y%m%d-%H%M%S")

# 保存带时间戳的模型
joblib.dump(model, f'svm_model_{timestamp}.joblib')

这样,我们可以根据时间戳来管理不同版本的模型,便于模型的回溯和更新。

6、使用模型进行持久化

一旦模型训练完成并保存,我们可以在后续的实际应用中加载该模型来进行预测,而无需重新训练。

例如,我们可以将保存的模型与 Web 服务、批处理作业或其他应用程序集成,使得模型可以反复使用,而无需重新训练。

Web 服务中使用加载的模型

例如,假设我们正在使用 Flask 创建一个简单的 Web 服务,通过 API 接口提供模型预测服务。在这种情况下,我们可以加载保存的模型进行实时预测。

实例

from flask import Flask, request, jsonify
import joblib
import numpy as np

app = Flask(__name__)

# 加载模型
model = joblib.load('svm_model.joblib')

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()  # 获取输入数据
    features = np.array(data['features']).reshape(1, -1)  # 转换成适合预测的格式
    prediction = model.predict(features)  # 使用加载的模型进行预测
    return jsonify({'prediction': prediction.tolist()})  # 返回预测结果

if __name__ == '__main__':
    app.run(debug=True)

您可以捐助,支持我们的公益事业。

1元 10元 50元





认证码: 验证码,看不清楚?请点击刷新验证码 必填



57 次浏览
1次