Python机器学习

在Python中实现SVM

在Python中实现SVM详细操作教程
要在Python中实现SVM,我们将从如下所示的标准库导入开始-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns; sns.set()
接下来,我们从sklearn.dataset.sample_generator创建一个具有线性可分离数据的样本数据集,以便使用SVM进行分类-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
from sklearn.datasets.samples_generator import make_blobs
X, y = make_blobs(n_samples = 100, centers = 2, random_state = 0, cluster_std = 0.50)
plt.scatter(X[:, 0], X[:, 1], c = y, s = 50, cmap = 'summer');
以下是生成具有100个样本和2个聚类的样本数据集后的输出-
在Python中实现SVM
我们知道SVM支持判别分类。它通过在二维的情况下简单地找到一条线,在多维的情况下通过歧管来简单地将类彼此划分。它在上面的数据集上实现如下-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
xfit = np.linspace(-1, 3.5)
plt.scatter(X[:, 0], X[:, 1], c = y, s = 50, cmap = 'summer')
plt.plot([0.6], [2.1], 'x', color = 'black', markeredgewidth = 4, markersize = 12)
for m, b in [(1, 0.65), (0.5, 1.6), (-0.2, 2.9)]:
plt.plot(xfit, m * xfit + b, '-k')
plt.xlim(-1, 3.5);
输出如下-
输出
从上面的输出中我们可以看到,有三种不同的分隔符可以完美地区分以上示例。
如前所述,SVM的主要目标是将数据集划分为类,以找到最大的边际超平面(MMH),因此,我们不必在类之间绘制零线,而可以在每条线周围画出一定宽度的边距,直至最近的点。可以完成以下操作-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
xfit = np.linspace(-1, 3.5)
plt.scatter(X[:, 0], X[:, 1], c = y, s = 50, cmap = 'summer')
for m, b, d in [(1, 0.65, 0.33), (0.5, 1.6, 0.55), (-0.2, 2.9, 0.2)]:
   yfit = m * xfit + b
   plt.plot(xfit, yfit, '-k')
   plt.fill_between(xfit, yfit - d, yfit + d, edgecolor='none',
   color = '#AAAAAA', alpha = 0.4)
plt.xlim(-1, 3.5);
最大边缘超平面
从上面的输出图像中,我们可以轻松地观察到判别式分类器中的"边距"。SVM将选择使边距最大化的线。
接下来,我们将使用Scikit-Learn的支持向量分类器在此数据上训练SVM模型。在这里,我们使用线性核来拟合SVM,如下所示:-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
from sklearn.svm import SVC # "Support vector classifier"
model = SVC(kernel = 'linear', C = 1E10)
model.fit(X, y)
输出如下-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
SVC(C=10000000000.0, cache_size=200, class_weight=None, coef0=0.0,
decision_function_shape='ovr', degree=3, gamma='auto_deprecated',
kernel='linear', max_iter=-1, probability=False, random_state=None,
shrinking=True, tol=0.001, verbose=False)
现在,为了更好地理解,下面将绘制2D SVC的决策函数-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
def decision_function(model, ax = None, plot_support = True):
   if ax is None:
      ax = plt.gca()
   xlim = ax.get_xlim()
   ylim = ax.get_ylim()
对于评估模型,我们需要如下创建网格-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
x = np.linspace(xlim[0], xlim[1], 30)
y = np.linspace(ylim[0], ylim[1], 30)
Y, X = np.meshgrid(y, x)
xy = np.vstack([X.ravel(), Y.ravel()]).T
P = model.decision_function(xy).reshape(X.shape)
接下来,我们需要绘制决策边界和边际,如下所示:-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
ax.contour(X, Y, P, colors = 'k', levels = [-1, 0, 1], alpha = 0.5, linestyles = ['--', '-', '--'])
现在,类似地绘制支持向量,如下所示:-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
if plot_support:
   ax.scatter(model.support_vectors_[:, 0],
   model.support_vectors_[:, 1], s = 300, linewidth = 1, facecolors = 'none');
ax.set_xlim(xlim)
ax.set_ylim(ylim)
现在,使用此功能如下拟合我们的模型-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
plt.scatter(X[:, 0], X[:, 1], c = y, s = 50, cmap = 'summer')
decision_function(model);
实施SVM模型
我们可以从上面的输出中观察到SVM分类器适合数据的边距,即虚线和支持向量,该适合度的关键元素与虚线接触。这些支持向量点存储在分类器的 support_vectors _属性中,如下所示-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
model.support_vectors_
输出如下-
# Filename : example.py
# Copyright : 2020 By Lidihuo
# Author by : www.lidihuo.com
# Date : 2020-08-27
array([[0.5323772 , 3.31338909], [2.11114739, 3.57660449], [1.46870582, 1.86947425]])
昵称: 邮箱:
Copyright © 2022 立地货 All Rights Reserved.
备案号:京ICP备14037608号-4