第4章:支持向量机
4.1 导入必要的库
首先,我们需要导入本章节所需的 Python 库:
python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler
代码解释
让我们详细了解每个导入的库的作用:
- 数学计算库
python
import numpy as np
numpy 是一个强大的数学计算库,就像是我们手中的科学计算器,提供了丰富的数学函数和数组操作功能。
- 数据可视化库
python
import matplotlib.pyplot as plt
matplotlib 是一个强大的绘图库,pyplot 是其中最常用的模块。它就像是我们的数字画笔,可以将数据转化为直观的图形。
- 机器学习相关库
python
from sklearn import svm
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler
- svm:支持向量机(Support Vector Machine)模块,提供了强大的分类算法
- make_blobs:用于生成用于分类的模拟数据的函数
- StandardScaler:用于数据标准化的工具
为什么需要数据标准化?
在机器学习中,数据标准化是一个重要的预处理步骤:
- 消除不同特征之间的量纲差异
- 使得所有特征都在相似的尺度上
- 加快模型的收敛速度
- 提高模型的数值稳定性
就像是在比较不同单位的数据时,我们需要先将它们转换到同一个标准下才能进行比较。
4.2 生成和准备数据
python
# 生成随机数据
X, y = make_blobs(n_samples=100, centers=2, random_state=42)
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
代码解释
- 生成模拟数据
python
X, y = make_blobs(n_samples=100, centers=2, random_state=42)
使用 make_blobs 函数生成用于分类的数据:
- n_samples=100:生成 100 个样本点
- centers=2:分为两个类别
- random_state=42:设置随机种子,确保结果可重现
- 数据标准化
python
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
使用 StandardScaler 对数据进行标准化处理:
- 创建标准化器对象
- fit_transform 方法同时完成了:
- 计算训练数据的均值和标准差
- 使用这些统计量来转换数据
4.3 创建和训练模型
python
# 创建并训练SVM模型
clf = svm.SVC(kernel='linear')
clf.fit(X_scaled, y)
代码解释
- 创建模型
python
clf = svm.SVC(kernel='linear')
创建一个支持向量机分类器:
- kernel='linear':使用线性核函数
- 线性核函数假设数据可以用一个超平面(在二维情况下是一条直线)来分隔
- 训练模型
python
clf.fit(X_scaled, y)
使用标准化后的数据训练模型:
- X_scaled:标准化后的特征数据
- y:类别标签
支持向量机的工作原理
SVM 通过寻找最优分隔超平面来实现分类:
- 最大化分类边界(margin)
- 只关注支持向量(最靠近分类边界的点)
- 可以通过核技巧处理非线性可分的数据
这就像是在两类点之间画一条线,使得:
- 线两边的点尽可能分开
- 线到最近的点的距离尽可能大
4.4 可视化分类结果
python
# 创建网格点
x_min, x_max = X_scaled[:, 0].min() - 1, X_scaled[:, 0].max() + 1
y_min, y_max = X_scaled[:, 1].min() - 1, X_scaled[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
np.arange(y_min, y_max, 0.02))
# 预测网格点的类别
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# 绘制结果
plt.figure(figsize=(10, 8))
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X_scaled[:, 0], X_scaled[:, 1], c=y, alpha=0.8)
plt.title('SVM分类结果')
plt.show()
代码解释
- 创建可视化网格
python
x_min, x_max = X_scaled[:, 0].min() - 1, X_scaled[:, 0].max() + 1
y_min, y_max = X_scaled[:, 1].min() - 1, X_scaled[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
np.arange(y_min, y_max, 0.02))
- 计算数据范围并扩展一个单位
- 创建均匀网格,步长为 0.02
- 网格用于绘制分类边界
- 预测网格点类别
python
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
- 将网格点展平并预测类别
- 将预测结果重塑为网格形状
- 绘制可视化结果
python
plt.figure(figsize=(10, 8))
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X_scaled[:, 0], X_scaled[:, 1], c=y, alpha=0.8)
plt.title('SVM分类结果')
plt.show()
- 创建 10×8 英寸的图形
- 绘制分类边界(半透明)
- 绘制数据点(用不同颜色表示类别)
- 添加标题并显示图形
4.5 总结
支持向量机(SVM)是一种强大的分类算法,它具有以下特点:
优点:
- 在高维空间中依然有效
- 对于数据点的分布没有强假设
- 只使用部分样本点(支持向量)来确定分类边界
- 通过核函数可以处理非线性分类问题
缺点:
- 对数据规模敏感,不适合大规模数据集
- 对参数选择敏感
- 需要进行特征缩放
- 结果不易解释
应用场景:
- 文本分类
- 图像分类
- 生物信息学
- 手写识别
在本章中,我们学习了:
- 如何生成和预处理分类数据
- 如何创建和训练 SVM 模型
- 如何可视化分类结果
SVM 的核心思想是找到最优的分类超平面,就像是在两类数据之间画一条最合适的分界线,使得:
- 不同类别的数据点被清晰地分开
- 分类边界到最近的数据点的距离最大
- 只需要关注对分类起决定性作用的支持向量