鸢尾花品种分类模型训练

文档说明

本文档详细梳理了鸢尾花(Iris)品种分类模型的全流程,包含数据加载、预处理、模型训练、评估、保存、预测六大核心步骤,所有代码均可直接运行,适合机器学习新手入门学习。

环境准备

1. 安装依赖库

在Anaconda Prompt/终端执行以下命令,安装所需Python库:

pip install pandas scikit-learn joblib

2. 数据集准备

确保iris.data文件与代码文件放在同一文件夹下(若文件路径不同,需在代码中修改路径)。


3.补充些特定环境准备

在anaconda prompt中安装

#查看所有环境
conda info --envs
#创建环境
conda create -n yolov8 python=3.8
#激活环境
conda activate yolov8
#查看python版本
python --version
#安装ipykernel
conda install ipykernel -y

完整流程(分步骤)

步骤1:导入所需库

作用:加载数据处理、模型训练、评估所需的工具库。

# 数据处理库:用于读取和处理结构化数据
import pandas as pd  
# 模型保存/加载库:用于保存训练好的模型,后续可直接复用
import joblib        
# 数据划分工具:将数据集拆分为训练集和测试集
from sklearn.model_selection import train_test_split
# 核心分类模型:逻辑回归(适合小数据集的入门级分类算法)
from sklearn.linear_model import LogisticRegression
# 模型评估指标:计算准确率、生成详细分类报告
from sklearn.metrics import accuracy_score, classification_report

步骤2:加载并探索数据集

作用:读取鸢尾花数据集,了解数据结构和基本信息。

# 定义列名(对应鸢尾花的4个特征 + 1个品种标签)
column_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']

# 读取iris.data文件(若文件在其他路径,需修改引号内内容,如:'C:/Users/xxx/Desktop/iris.data')
iris_df = pd.read_csv('iris.data', names=column_names)

# 打印数据集前5行,直观查看数据格式
print("数据集前5行:")
print(iris_df.head())

# 打印数据类型、缺失值等基本信息
print("\n数据集基本信息:")
print(iris_df.info())

# 统计各鸢尾花品种的样本数量
print("\n各品种样本数量:")
print(iris_df['species'].value_counts())

关键说明

  • 鸢尾花数据集共150条样本,包含3个品种(Iris-setosa、Iris-versicolor、Iris-virginica);

  • 4个特征均为数值型(花萼长度、花萼宽度、花瓣长度、花瓣宽度),无缺失值,无需额外补全。

步骤3:数据预处理

作用:分离特征和标签,划分训练集/测试集(训练集用于模型学习,测试集用于验证模型效果)。

# 分离特征(X)和标签(y)
# X:4个数值特征(模型输入)
X = iris_df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]  
# y:品种标签(模型输出,即需要预测的目标)
y = iris_df['species']                                                     

# 划分训练集(80%)和测试集(20%)
# random_state=42:固定随机种子,保证每次划分结果一致(新手无需修改)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 打印划分后的数据量,验证是否正确
print(f"\n训练集样本数:{len(X_train)}, 测试集样本数:{len(X_test)}")

关键说明

  • test_size=0.2:测试集占比20%(30条样本),训练集占比80%(120条样本),是分类任务的常规比例;

  • 划分后需保证训练集/测试集样本数之和等于原数据集(150条)。

步骤4:模型训练

作用:用训练集数据训练逻辑回归模型,让模型学习“特征→品种”的对应关系。

# 初始化逻辑回归模型
# max_iter=200:增大迭代次数,避免因数据未收敛导致的警告(新手无需修改)
model = LogisticRegression(max_iter=200)  

# 核心步骤:用训练集训练模型
model.fit(X_train, y_train)

print("\n模型训练完成!")

关键说明

  • 逻辑回归是线性分类模型,计算速度快,适合鸢尾花这种小数据集;

  • 训练过程无输出,仅需确认无报错即可。

步骤5:模型评估

作用:用测试集验证模型的泛化能力(即模型对新数据的预测效果)。

# 用训练好的模型预测测试集数据
y_pred = model.predict(X_test)

# 计算准确率(所有预测正确的样本占测试集总样本的比例)
accuracy = accuracy_score(y_test, y_pred)
print(f"\n模型准确率:{accuracy:.2f} (1.0表示100%正确)")

# 生成详细分类报告(包含精确率、召回率、F1值,全面评估模型)
print("\n详细分类报告:")
print(classification_report(y_test, y_pred))

关键说明

  • 鸢尾花数据集简单,模型准确率通常≥0.95(甚至1.0),属于正常现象;

  • 分类报告中:

    • 精确率(precision):预测为某品种的样本中,实际是该品种的比例;

    • 召回率(recall):实际为某品种的样本中,被正确预测的比例;

    • F1值:精确率和召回率的综合指标,越接近1越好。

步骤6:模型保存

作用:将训练好的模型保存到本地,后续无需重复训练,直接加载即可使用。

# 定义模型保存路径和文件名(后缀用.joblib或.pkl均可)
model_path = 'iris_classification_model.joblib'

# 保存模型到本地
joblib.dump(model, model_path)

print(f"\n模型已保存到:{model_path}")

关键说明

  • 保存后的.joblib文件包含模型的所有训练参数,是模型的“成品文件”;

  • 该文件可复制、分享,其他环境只需安装相同依赖即可加载使用。

步骤7:加载模型并预测新数据

作用:这是保存模型的核心用途——用训练好的模型对新的鸢尾花数据做品种预测。

# 加载本地保存的模型
loaded_model = joblib.load(model_path)
print("\n模型加载完成,开始预测新数据!")

# 准备新的鸢尾花特征数据(格式必须和训练时一致:[[特征1, 特征2, 特征3, 特征4]])
# 示例数据:可替换成你自己的鸢尾花尺寸
new_iris_data = [
    [5.1, 3.5, 1.4, 0.2],  # 预期预测结果:Iris-setosa
    [6.2, 2.9, 4.3, 1.3],  # 预期预测结果:Iris-versicolor
    [7.3, 2.9, 6.3, 1.8]   # 预期预测结果:Iris-virginica
]

# 用加载的模型预测新数据
predictions = loaded_model.predict(new_iris_data)

# 输出预测结果
print("\n新数据预测结果:")
for i, data in enumerate(new_iris_data):
    print(f"特征:{data} → 预测品种:{predictions[i]}")

关键说明

  • 新数据必须是二维列表(如[[5.1,3.5,1.4,0.2]]),不能是一维列表([5.1,3.5,1.4,0.2]),否则会报错;

  • 特征顺序必须和训练时一致(花萼长→花萼宽→花瓣长→花瓣宽)。


常见问题解决

  1. 文件找不到报错

    • 检查iris.data文件路径是否正确,或直接使用绝对路径(如'C:/Users/xxx/Desktop/iris.data');
  2. 模型收敛警告

    • 初始化模型时添加max_iter=200(如LogisticRegression(max_iter=200));
  3. 预测时报错“维度不匹配”

    • 确保新数据是二维列表,例如将[5.1,3.5,1.4,0.2]改为[[5.1,3.5,1.4,0.2]]

核心总结

  1. 鸢尾花模型训练的标准流程:加载数据→预处理→训练→评估→保存→预测,适用于所有结构化数据的分类任务;

  2. 逻辑回归是入门级线性分类模型,适合小数据集,训练速度极快;

  3. .joblib文件是模型的“成品”,核心用途是加载后直接预测新数据,无需重复训练。