使用pcolormesh绘制鸢尾花的分类图

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# 参考 https://www.cnblogs.com/shenxiaolin/p/8857158.html , https://blog.csdn.net/liulina603/article/details/78676723
from sklearn.linear_model import LogisticRegression #导入逻辑回归模型
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
#多次分类问题
# 1.花萼的长、宽
# 2.花瓣的长、宽
# 3.花萼长、花瓣宽
# 4.花瓣长、花萼宽

'''
datas
labels
'''
def logistic(datas, labels, titleOfChart, x_label, y_label):
plt.figure()#创建新画布
datas=np.array(datas)
datas=datas.astype(float) #将二维的字符串数组转换成浮点数数组
labels=np.array(labels) #150个类别,50个'Iris-setosa',50个'Iris-versicolor',50个'Iris-virginica'
clf = LogisticRegression() #获取Logistic线性回归模型
#将字符串类型的标签转化为数字类型,进行数字编码处理
for i in range(len(labels)):
if labels[i] == 'Iris-setosa':
labels[i] = 0
elif labels[i] == 'Iris-versicolor':
labels[i] = 1
else:
labels[i] = 2
labels = labels.astype(np.int)#降低内存空间

#模型建立
clf.fit(datas, labels)#使用已有的数据和标签进行模型拟合

#数据准备阶段
N, M = 500, 500 # 横纵各采样多少个值
x1_min, x2_min = datas.min(axis=0)
x1_max, x2_max = datas.max(axis=0)
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点
x_show = np.stack((x1.flat, x2.flat), axis=1) # 测试点
y_predict = clf.predict(x_show)#预测值

cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
plt.pcolormesh(x1, x2, y_predict.reshape(x1.shape), shading='auto', cmap=cm_light)
plt.scatter(datas[:,0], datas[:,1], c=labels.transpose(), cmap=cm_dark, marker='o', edgecolors='k')

plt.xlim(x1_min, x1_max)
plt.ylim(x2_min, x2_max)
#设置x轴和y轴的标签以及主题
plt.title(titleOfChart, fontsize=15)
plt.xlabel(x_label, fontsize=11)
plt.ylabel(y_label, fontsize=11)

plt.show()


# 设置中文字体
mpl.rcParams["font.sans-serif"]=["SimHei"]
mpl.rcParams["axes.unicode_minus"]=False

cm_light = mpl.colors.ListedColormap(['g', 'r', 'b'])
attributes=['SepalLength','SepalWidth','PetalLength','PetalWidth'] #鸢尾花的四个属性名
#按照不同的分类标准收集数据
datas1=[]
datas2=[]
datas3=[]
datas4=[]
labels1=[]#标签列
data_file=open('iris.txt','r')
#数据初始化
for line in data_file.readlines():
linedata = line.split(',')
#剔除异常数据
if len(linedata) == 5:
linedata = np.array(linedata) #转化为多维数组
datas1.append(linedata[:2]) # [0:2]花萼长宽
labels1.append(linedata[-1].replace('\n', '')) # 获取150朵花儿所属的类别
datas2.append(linedata[2:4]) # [2:4]花瓣长宽
datas3.append(linedata[:][0:4:3]) # [0,3]花萼长,花瓣宽
datas4.append(linedata[1:3]) # [1:3]花萼宽,花瓣长

data_file.close()

logistic(datas1,labels1, "花萼长,花萼宽", "花萼长/(cm)", "花萼宽/(cm)")
logistic(datas2,labels1, "花瓣长,花瓣宽", "花瓣长/(cm)", "花瓣宽/(cm)")
logistic(datas3,labels1, "花萼长,花瓣宽", "花萼长/(cm)", "花瓣宽/(cm)")
logistic(datas4,labels1, "花萼宽,花瓣长", "花萼宽/(cm)", "花瓣长/(cm)")

其中iris.txt的文件内容如下图所示。

1

没有列名,我们使用open打开文件,文件中使用了“,”进行分割,首先明确每行有5条数据,分别是花萼长、花萼宽、花瓣长、花瓣宽、鸢尾花的类别。

因为我们要对鸢尾花做分类,我们本次任务是从多个属性组合来检查分类效果。

1.花萼的长、宽

2.花瓣的长、宽

3.花萼长、花瓣宽

4.花瓣长、花萼宽

所以我们分别统计这些数据,然后放入Logistic线性回归模型中去训练,然后使用pcolormesh来实现最终的分类效果展示。

分类图的绘画

训练完模型后,现在需要画出分类边界,首先需要在横纵坐标各取500点,一共组成250000个点,然后把这250000个点送进我们训练好的Logistic模型,来算出所属的种类,代码如下:

1
2
3
4
5
6
7
8
9
#数据准备阶段
N, M = 500, 500 # 横纵各采样多少个值
x1_min, x2_min = datas.min(axis=0)
x1_max, x2_max = datas.max(axis=0)
t1 = np.linspace(x1_min, x1_max, N)
t2 = np.linspace(x2_min, x2_max, M)
x1, x2 = np.meshgrid(t1, t2) # 生成网格采样点
x_show = np.stack((x1.flat, x2.flat), axis=1) # 测试点
y_predict = clf.predict(x_show)#预测值

接着就可以绘制出分类图了。由于该数据集中一共有三种鸢尾花,所以绘制图片的时候需要三种颜色

1
2
cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])

接着使用plt.pcolormesh来绘制分类图,这里请注意,如果不写入shading='auto',代码会发出警告,不会影响运行。

1
2
plt.pcolormesh(x1, x2, y_predict.reshape(x1.shape), shading='auto', cmap=cm_light)
plt.scatter(datas[:,0], datas[:,1], c=labels.transpose(), cmap=cm_dark, marker='o', edgecolors='k')

plt.pcolormesh()会根据y_predict的结果自动在cmap里选择颜色。

2


使用pcolormesh绘制鸢尾花的分类图
https://fulequn.github.io/2020/10/Article202010172/
作者
Fulequn
发布于
2020年10月17日
许可协议