In [8]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler
from sklearn.naive_bayes import GaussianNB, MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
def iris_type(s):
it = {'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}
return it[s]
In [15]:
data = np.loadtxt('D:\mlInAction\8.iris.data', encoding='utf-8', dtype=float, delimiter=',',
converters={4: iris_type})
data
Out[15]:
array([[5.1, 3.5, 1.4, 0.2, 0. ], [4.9, 3. , 1.4, 0.2, 0. ], [4.7, 3.2, 1.3, 0.2, 0. ], [4.6, 3.1, 1.5, 0.2, 0. ], [5. , 3.6, 1.4, 0.2, 0. ], [5.4, 3.9, 1.7, 0.4, 0. ], [4.6, 3.4, 1.4, 0.3, 0. ], [5. , 3.4, 1.5, 0.2, 0. ], [4.4, 2.9, 1.4, 0.2, 0. ], [4.9, 3.1, 1.5, 0.1, 0. ], [5.4, 3.7, 1.5, 0.2, 0. ], [4.8, 3.4, 1.6, 0.2, 0. ], [4.8, 3. , 1.4, 0.1, 0. ], [4.3, 3. , 1.1, 0.1, 0. ], [5.8, 4. , 1.2, 0.2, 0. ], [5.7, 4.4, 1.5, 0.4, 0. ], [5.4, 3.9, 1.3, 0.4, 0. ], [5.1, 3.5, 1.4, 0.3, 0. ], [5.7, 3.8, 1.7, 0.3, 0. ], [5.1, 3.8, 1.5, 0.3, 0. ], [5.4, 3.4, 1.7, 0.2, 0. ], [5.1, 3.7, 1.5, 0.4, 0. ], [4.6, 3.6, 1. , 0.2, 0. ], [5.1, 3.3, 1.7, 0.5, 0. ], [4.8, 3.4, 1.9, 0.2, 0. ], [5. , 3. , 1.6, 0.2, 0. ], [5. , 3.4, 1.6, 0.4, 0. ], [5.2, 3.5, 1.5, 0.2, 0. ], [5.2, 3.4, 1.4, 0.2, 0. ], [4.7, 3.2, 1.6, 0.2, 0. ], [4.8, 3.1, 1.6, 0.2, 0. ], [5.4, 3.4, 1.5, 0.4, 0. ], [5.2, 4.1, 1.5, 0.1, 0. ], [5.5, 4.2, 1.4, 0.2, 0. ], [4.9, 3.1, 1.5, 0.1, 0. ], [5. , 3.2, 1.2, 0.2, 0. ], [5.5, 3.5, 1.3, 0.2, 0. ], [4.9, 3.1, 1.5, 0.1, 0. ], [4.4, 3. , 1.3, 0.2, 0. ], [5.1, 3.4, 1.5, 0.2, 0. ], [5. , 3.5, 1.3, 0.3, 0. ], [4.5, 2.3, 1.3, 0.3, 0. ], [4.4, 3.2, 1.3, 0.2, 0. ], [5. , 3.5, 1.6, 0.6, 0. ], [5.1, 3.8, 1.9, 0.4, 0. ], [4.8, 3. , 1.4, 0.3, 0. ], [5.1, 3.8, 1.6, 0.2, 0. ], [4.6, 3.2, 1.4, 0.2, 0. ], [5.3, 3.7, 1.5, 0.2, 0. ], [5. , 3.3, 1.4, 0.2, 0. ], [7. , 3.2, 4.7, 1.4, 1. ], [6.4, 3.2, 4.5, 1.5, 1. ], [6.9, 3.1, 4.9, 1.5, 1. ], [5.5, 2.3, 4. , 1.3, 1. ], [6.5, 2.8, 4.6, 1.5, 1. ], [5.7, 2.8, 4.5, 1.3, 1. ], [6.3, 3.3, 4.7, 1.6, 1. ], [4.9, 2.4, 3.3, 1. , 1. ], [6.6, 2.9, 4.6, 1.3, 1. ], [5.2, 2.7, 3.9, 1.4, 1. ], [5. , 2. , 3.5, 1. , 1. ], [5.9, 3. , 4.2, 1.5, 1. ], [6. , 2.2, 4. , 1. , 1. ], [6.1, 2.9, 4.7, 1.4, 1. ], [5.6, 2.9, 3.6, 1.3, 1. ], [6.7, 3.1, 4.4, 1.4, 1. ], [5.6, 3. , 4.5, 1.5, 1. ], [5.8, 2.7, 4.1, 1. , 1. ], [6.2, 2.2, 4.5, 1.5, 1. ], [5.6, 2.5, 3.9, 1.1, 1. ], [5.9, 3.2, 4.8, 1.8, 1. ], [6.1, 2.8, 4. , 1.3, 1. ], [6.3, 2.5, 4.9, 1.5, 1. ], [6.1, 2.8, 4.7, 1.2, 1. ], [6.4, 2.9, 4.3, 1.3, 1. ], [6.6, 3. , 4.4, 1.4, 1. ], [6.8, 2.8, 4.8, 1.4, 1. ], [6.7, 3. , 5. , 1.7, 1. ], [6. , 2.9, 4.5, 1.5, 1. ], [5.7, 2.6, 3.5, 1. , 1. ], [5.5, 2.4, 3.8, 1.1, 1. ], [5.5, 2.4, 3.7, 1. , 1. ], [5.8, 2.7, 3.9, 1.2, 1. ], [6. , 2.7, 5.1, 1.6, 1. ], [5.4, 3. , 4.5, 1.5, 1. ], [6. , 3.4, 4.5, 1.6, 1. ], [6.7, 3.1, 4.7, 1.5, 1. ], [6.3, 2.3, 4.4, 1.3, 1. ], [5.6, 3. , 4.1, 1.3, 1. ], [5.5, 2.5, 4. , 1.3, 1. ], [5.5, 2.6, 4.4, 1.2, 1. ], [6.1, 3. , 4.6, 1.4, 1. ], [5.8, 2.6, 4. , 1.2, 1. ], [5. , 2.3, 3.3, 1. , 1. ], [5.6, 2.7, 4.2, 1.3, 1. ], [5.7, 3. , 4.2, 1.2, 1. ], [5.7, 2.9, 4.2, 1.3, 1. ], [6.2, 2.9, 4.3, 1.3, 1. ], [5.1, 2.5, 3. , 1.1, 1. ], [5.7, 2.8, 4.1, 1.3, 1. ], [6.3, 3.3, 6. , 2.5, 2. ], [5.8, 2.7, 5.1, 1.9, 2. ], [7.1, 3. , 5.9, 2.1, 2. ], [6.3, 2.9, 5.6, 1.8, 2. ], [6.5, 3. , 5.8, 2.2, 2. ], [7.6, 3. , 6.6, 2.1, 2. ], [4.9, 2.5, 4.5, 1.7, 2. ], [7.3, 2.9, 6.3, 1.8, 2. ], [6.7, 2.5, 5.8, 1.8, 2. ], [7.2, 3.6, 6.1, 2.5, 2. ], [6.5, 3.2, 5.1, 2. , 2. ], [6.4, 2.7, 5.3, 1.9, 2. ], [6.8, 3. , 5.5, 2.1, 2. ], [5.7, 2.5, 5. , 2. , 2. ], [5.8, 2.8, 5.1, 2.4, 2. ], [6.4, 3.2, 5.3, 2.3, 2. ], [6.5, 3. , 5.5, 1.8, 2. ], [7.7, 3.8, 6.7, 2.2, 2. ], [7.7, 2.6, 6.9, 2.3, 2. ], [6. , 2.2, 5. , 1.5, 2. ], [6.9, 3.2, 5.7, 2.3, 2. ], [5.6, 2.8, 4.9, 2. , 2. ], [7.7, 2.8, 6.7, 2. , 2. ], [6.3, 2.7, 4.9, 1.8, 2. ], [6.7, 3.3, 5.7, 2.1, 2. ], [7.2, 3.2, 6. , 1.8, 2. ], [6.2, 2.8, 4.8, 1.8, 2. ], [6.1, 3. , 4.9, 1.8, 2. ], [6.4, 2.8, 5.6, 2.1, 2. ], [7.2, 3. , 5.8, 1.6, 2. ], [7.4, 2.8, 6.1, 1.9, 2. ], [7.9, 3.8, 6.4, 2. , 2. ], [6.4, 2.8, 5.6, 2.2, 2. ], [6.3, 2.8, 5.1, 1.5, 2. ], [6.1, 2.6, 5.6, 1.4, 2. ], [7.7, 3. , 6.1, 2.3, 2. ], [6.3, 3.4, 5.6, 2.4, 2. ], [6.4, 3.1, 5.5, 1.8, 2. ], [6. , 3. , 4.8, 1.8, 2. ], [6.9, 3.1, 5.4, 2.1, 2. ], [6.7, 3.1, 5.6, 2.4, 2. ], [6.9, 3.1, 5.1, 2.3, 2. ], [5.8, 2.7, 5.1, 1.9, 2. ], [6.8, 3.2, 5.9, 2.3, 2. ], [6.7, 3.3, 5.7, 2.5, 2. ], [6.7, 3. , 5.2, 2.3, 2. ], [6.3, 2.5, 5. , 1.9, 2. ], [6.5, 3. , 5.2, 2. , 2. ], [6.2, 3.4, 5.4, 2.3, 2. ], [5.9, 3. , 5.1, 1.8, 2. ]])
In [16]:
x, y = np.split(data, (4,), axis=1) # 前四列是x最后一列是y
x
Out[16]:
array([[5.1, 3.5, 1.4, 0.2], [4.9, 3. , 1.4, 0.2], [4.7, 3.2, 1.3, 0.2], [4.6, 3.1, 1.5, 0.2], [5. , 3.6, 1.4, 0.2], [5.4, 3.9, 1.7, 0.4], [4.6, 3.4, 1.4, 0.3], [5. , 3.4, 1.5, 0.2], [4.4, 2.9, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5.4, 3.7, 1.5, 0.2], [4.8, 3.4, 1.6, 0.2], [4.8, 3. , 1.4, 0.1], [4.3, 3. , 1.1, 0.1], [5.8, 4. , 1.2, 0.2], [5.7, 4.4, 1.5, 0.4], [5.4, 3.9, 1.3, 0.4], [5.1, 3.5, 1.4, 0.3], [5.7, 3.8, 1.7, 0.3], [5.1, 3.8, 1.5, 0.3], [5.4, 3.4, 1.7, 0.2], [5.1, 3.7, 1.5, 0.4], [4.6, 3.6, 1. , 0.2], [5.1, 3.3, 1.7, 0.5], [4.8, 3.4, 1.9, 0.2], [5. , 3. , 1.6, 0.2], [5. , 3.4, 1.6, 0.4], [5.2, 3.5, 1.5, 0.2], [5.2, 3.4, 1.4, 0.2], [4.7, 3.2, 1.6, 0.2], [4.8, 3.1, 1.6, 0.2], [5.4, 3.4, 1.5, 0.4], [5.2, 4.1, 1.5, 0.1], [5.5, 4.2, 1.4, 0.2], [4.9, 3.1, 1.5, 0.1], [5. , 3.2, 1.2, 0.2], [5.5, 3.5, 1.3, 0.2], [4.9, 3.1, 1.5, 0.1], [4.4, 3. , 1.3, 0.2], [5.1, 3.4, 1.5, 0.2], [5. , 3.5, 1.3, 0.3], [4.5, 2.3, 1.3, 0.3], [4.4, 3.2, 1.3, 0.2], [5. , 3.5, 1.6, 0.6], [5.1, 3.8, 1.9, 0.4], [4.8, 3. , 1.4, 0.3], [5.1, 3.8, 1.6, 0.2], [4.6, 3.2, 1.4, 0.2], [5.3, 3.7, 1.5, 0.2], [5. , 3.3, 1.4, 0.2], [7. , 3.2, 4.7, 1.4], [6.4, 3.2, 4.5, 1.5], [6.9, 3.1, 4.9, 1.5], [5.5, 2.3, 4. , 1.3], [6.5, 2.8, 4.6, 1.5], [5.7, 2.8, 4.5, 1.3], [6.3, 3.3, 4.7, 1.6], [4.9, 2.4, 3.3, 1. ], [6.6, 2.9, 4.6, 1.3], [5.2, 2.7, 3.9, 1.4], [5. , 2. , 3.5, 1. ], [5.9, 3. , 4.2, 1.5], [6. , 2.2, 4. , 1. ], [6.1, 2.9, 4.7, 1.4], [5.6, 2.9, 3.6, 1.3], [6.7, 3.1, 4.4, 1.4], [5.6, 3. , 4.5, 1.5], [5.8, 2.7, 4.1, 1. ], [6.2, 2.2, 4.5, 1.5], [5.6, 2.5, 3.9, 1.1], [5.9, 3.2, 4.8, 1.8], [6.1, 2.8, 4. , 1.3], [6.3, 2.5, 4.9, 1.5], [6.1, 2.8, 4.7, 1.2], [6.4, 2.9, 4.3, 1.3], [6.6, 3. , 4.4, 1.4], [6.8, 2.8, 4.8, 1.4], [6.7, 3. , 5. , 1.7], [6. , 2.9, 4.5, 1.5], [5.7, 2.6, 3.5, 1. ], [5.5, 2.4, 3.8, 1.1], [5.5, 2.4, 3.7, 1. ], [5.8, 2.7, 3.9, 1.2], [6. , 2.7, 5.1, 1.6], [5.4, 3. , 4.5, 1.5], [6. , 3.4, 4.5, 1.6], [6.7, 3.1, 4.7, 1.5], [6.3, 2.3, 4.4, 1.3], [5.6, 3. , 4.1, 1.3], [5.5, 2.5, 4. , 1.3], [5.5, 2.6, 4.4, 1.2], [6.1, 3. , 4.6, 1.4], [5.8, 2.6, 4. , 1.2], [5. , 2.3, 3.3, 1. ], [5.6, 2.7, 4.2, 1.3], [5.7, 3. , 4.2, 1.2], [5.7, 2.9, 4.2, 1.3], [6.2, 2.9, 4.3, 1.3], [5.1, 2.5, 3. , 1.1], [5.7, 2.8, 4.1, 1.3], [6.3, 3.3, 6. , 2.5], [5.8, 2.7, 5.1, 1.9], [7.1, 3. , 5.9, 2.1], [6.3, 2.9, 5.6, 1.8], [6.5, 3. , 5.8, 2.2], [7.6, 3. , 6.6, 2.1], [4.9, 2.5, 4.5, 1.7], [7.3, 2.9, 6.3, 1.8], [6.7, 2.5, 5.8, 1.8], [7.2, 3.6, 6.1, 2.5], [6.5, 3.2, 5.1, 2. ], [6.4, 2.7, 5.3, 1.9], [6.8, 3. , 5.5, 2.1], [5.7, 2.5, 5. , 2. ], [5.8, 2.8, 5.1, 2.4], [6.4, 3.2, 5.3, 2.3], [6.5, 3. , 5.5, 1.8], [7.7, 3.8, 6.7, 2.2], [7.7, 2.6, 6.9, 2.3], [6. , 2.2, 5. , 1.5], [6.9, 3.2, 5.7, 2.3], [5.6, 2.8, 4.9, 2. ], [7.7, 2.8, 6.7, 2. ], [6.3, 2.7, 4.9, 1.8], [6.7, 3.3, 5.7, 2.1], [7.2, 3.2, 6. , 1.8], [6.2, 2.8, 4.8, 1.8], [6.1, 3. , 4.9, 1.8], [6.4, 2.8, 5.6, 2.1], [7.2, 3. , 5.8, 1.6], [7.4, 2.8, 6.1, 1.9], [7.9, 3.8, 6.4, 2. ], [6.4, 2.8, 5.6, 2.2], [6.3, 2.8, 5.1, 1.5], [6.1, 2.6, 5.6, 1.4], [7.7, 3. , 6.1, 2.3], [6.3, 3.4, 5.6, 2.4], [6.4, 3.1, 5.5, 1.8], [6. , 3. , 4.8, 1.8], [6.9, 3.1, 5.4, 2.1], [6.7, 3.1, 5.6, 2.4], [6.9, 3.1, 5.1, 2.3], [5.8, 2.7, 5.1, 1.9], [6.8, 3.2, 5.9, 2.3], [6.7, 3.3, 5.7, 2.5], [6.7, 3. , 5.2, 2.3], [6.3, 2.5, 5. , 1.9], [6.5, 3. , 5.2, 2. ], [6.2, 3.4, 5.4, 2.3], [5.9, 3. , 5.1, 1.8]])
In [17]:
y
Out[17]:
array([[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.], [2.]])
In [18]:
x = x[:, :2] # 只取前两列作为x
x
Out[18]:
array([[5.1, 3.5], [4.9, 3. ], [4.7, 3.2], [4.6, 3.1], [5. , 3.6], [5.4, 3.9], [4.6, 3.4], [5. , 3.4], [4.4, 2.9], [4.9, 3.1], [5.4, 3.7], [4.8, 3.4], [4.8, 3. ], [4.3, 3. ], [5.8, 4. ], [5.7, 4.4], [5.4, 3.9], [5.1, 3.5], [5.7, 3.8], [5.1, 3.8], [5.4, 3.4], [5.1, 3.7], [4.6, 3.6], [5.1, 3.3], [4.8, 3.4], [5. , 3. ], [5. , 3.4], [5.2, 3.5], [5.2, 3.4], [4.7, 3.2], [4.8, 3.1], [5.4, 3.4], [5.2, 4.1], [5.5, 4.2], [4.9, 3.1], [5. , 3.2], [5.5, 3.5], [4.9, 3.1], [4.4, 3. ], [5.1, 3.4], [5. , 3.5], [4.5, 2.3], [4.4, 3.2], [5. , 3.5], [5.1, 3.8], [4.8, 3. ], [5.1, 3.8], [4.6, 3.2], [5.3, 3.7], [5. , 3.3], [7. , 3.2], [6.4, 3.2], [6.9, 3.1], [5.5, 2.3], [6.5, 2.8], [5.7, 2.8], [6.3, 3.3], [4.9, 2.4], [6.6, 2.9], [5.2, 2.7], [5. , 2. ], [5.9, 3. ], [6. , 2.2], [6.1, 2.9], [5.6, 2.9], [6.7, 3.1], [5.6, 3. ], [5.8, 2.7], [6.2, 2.2], [5.6, 2.5], [5.9, 3.2], [6.1, 2.8], [6.3, 2.5], [6.1, 2.8], [6.4, 2.9], [6.6, 3. ], [6.8, 2.8], [6.7, 3. ], [6. , 2.9], [5.7, 2.6], [5.5, 2.4], [5.5, 2.4], [5.8, 2.7], [6. , 2.7], [5.4, 3. ], [6. , 3.4], [6.7, 3.1], [6.3, 2.3], [5.6, 3. ], [5.5, 2.5], [5.5, 2.6], [6.1, 3. ], [5.8, 2.6], [5. , 2.3], [5.6, 2.7], [5.7, 3. ], [5.7, 2.9], [6.2, 2.9], [5.1, 2.5], [5.7, 2.8], [6.3, 3.3], [5.8, 2.7], [7.1, 3. ], [6.3, 2.9], [6.5, 3. ], [7.6, 3. ], [4.9, 2.5], [7.3, 2.9], [6.7, 2.5], [7.2, 3.6], [6.5, 3.2], [6.4, 2.7], [6.8, 3. ], [5.7, 2.5], [5.8, 2.8], [6.4, 3.2], [6.5, 3. ], [7.7, 3.8], [7.7, 2.6], [6. , 2.2], [6.9, 3.2], [5.6, 2.8], [7.7, 2.8], [6.3, 2.7], [6.7, 3.3], [7.2, 3.2], [6.2, 2.8], [6.1, 3. ], [6.4, 2.8], [7.2, 3. ], [7.4, 2.8], [7.9, 3.8], [6.4, 2.8], [6.3, 2.8], [6.1, 2.6], [7.7, 3. ], [6.3, 3.4], [6.4, 3.1], [6. , 3. ], [6.9, 3.1], [6.7, 3.1], [6.9, 3.1], [5.8, 2.7], [6.8, 3.2], [6.7, 3.3], [6.7, 3. ], [6.3, 2.5], [6.5, 3. ], [6.2, 3.4], [5.9, 3. ]])
In [19]:
gnb = Pipeline([
('sc', StandardScaler()), # 把数据进行高斯标准化,以0为均值,1为方差
('clf', GaussianNB())]) # 假定数据为高斯分布
In [20]:
y.ravel() # 转化为行向量
Out[20]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
In [21]:
gnb.fit(x, y.ravel())
Out[21]:
Pipeline(memory=None, steps=[('sc', StandardScaler(copy=True, with_mean=True, with_std=True)), ('clf', GaussianNB(priors=None, var_smoothing=1e-09))])
In [23]:
y_hat = gnb.predict(x)
y_hat
Out[23]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 2., 2., 2., 1., 2., 1., 2., 1., 2., 1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 1., 2., 2., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 1., 2., 1., 2., 2., 1., 2., 2., 2., 2., 1., 2., 1., 1., 2., 2., 2., 2., 1., 2., 1., 2., 1., 2., 2., 1., 1., 1., 2., 2., 2., 1., 1., 1., 2., 2., 2., 1., 2., 2., 2., 1., 2., 2., 2., 1., 2., 2., 1.])
In [24]:
y = y.reshape(-1) # 相当于y.ravel()
y
Out[24]:
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.])
In [25]:
result = y_hat == y
result
Out[25]:
array([ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, False, False, False, True, False, True, False, True, False, True, True, True, True, True, True, False, True, True, True, True, True, True, True, True, False, False, False, False, True, True, True, True, True, True, True, False, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, False, True, True, False, True, True, True, True, False, True, False, False, True, True, True, True, False, True, False, True, False, True, True, False, False, False, True, True, True, False, False, False, True, True, True, False, True, True, True, False, True, True, True, False, True, True, False])
In [27]:
acc = np.mean(result) # 相当于把true当成1,false为0,求平均值,即为准确率
acc
Out[27]:
0.78
以下为版本2
In [29]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler, MinMaxScaler, PolynomialFeatures
from sklearn.naive_bayes import GaussianNB, MultinomialNB #高斯贝叶斯和多项式朴素贝叶斯
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# 设置属性防止中文乱码
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
# 花萼长度、花萼宽度,花瓣长度,花瓣宽度
iris_feature_E = 'sepal length', 'sepal width', 'petal length', 'petal width'
iris_feature_C = u'花萼长度', u'花萼宽度', u'花瓣长度', u'花瓣宽度'
iris_class = 'Iris-setosa', 'Iris-versicolor', 'Iris-virginica'
features = [2, 3]
# 读取数据
path = 'D:\mlInAction\8.iris.data' # 数据文件路径
data = pd.read_csv(path, header=None)
data
Out[29]:
0 | 1 | 2 | 3 | 4 | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
1 | 4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa |
2 | 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
3 | 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
4 | 5.0 | 3.6 | 1.4 | 0.2 | Iris-setosa |
5 | 5.4 | 3.9 | 1.7 | 0.4 | Iris-setosa |
6 | 4.6 | 3.4 | 1.4 | 0.3 | Iris-setosa |
7 | 5.0 | 3.4 | 1.5 | 0.2 | Iris-setosa |
8 | 4.4 | 2.9 | 1.4 | 0.2 | Iris-setosa |
9 | 4.9 | 3.1 | 1.5 | 0.1 | Iris-setosa |
10 | 5.4 | 3.7 | 1.5 | 0.2 | Iris-setosa |
11 | 4.8 | 3.4 | 1.6 | 0.2 | Iris-setosa |
12 | 4.8 | 3.0 | 1.4 | 0.1 | Iris-setosa |
13 | 4.3 | 3.0 | 1.1 | 0.1 | Iris-setosa |
14 | 5.8 | 4.0 | 1.2 | 0.2 | Iris-setosa |
15 | 5.7 | 4.4 | 1.5 | 0.4 | Iris-setosa |
16 | 5.4 | 3.9 | 1.3 | 0.4 | Iris-setosa |
17 | 5.1 | 3.5 | 1.4 | 0.3 | Iris-setosa |
18 | 5.7 | 3.8 | 1.7 | 0.3 | Iris-setosa |
19 | 5.1 | 3.8 | 1.5 | 0.3 | Iris-setosa |
20 | 5.4 | 3.4 | 1.7 | 0.2 | Iris-setosa |
21 | 5.1 | 3.7 | 1.5 | 0.4 | Iris-setosa |
22 | 4.6 | 3.6 | 1.0 | 0.2 | Iris-setosa |
23 | 5.1 | 3.3 | 1.7 | 0.5 | Iris-setosa |
24 | 4.8 | 3.4 | 1.9 | 0.2 | Iris-setosa |
25 | 5.0 | 3.0 | 1.6 | 0.2 | Iris-setosa |
26 | 5.0 | 3.4 | 1.6 | 0.4 | Iris-setosa |
27 | 5.2 | 3.5 | 1.5 | 0.2 | Iris-setosa |
28 | 5.2 | 3.4 | 1.4 | 0.2 | Iris-setosa |
29 | 4.7 | 3.2 | 1.6 | 0.2 | Iris-setosa |
... | ... | ... | ... | ... | ... |
120 | 6.9 | 3.2 | 5.7 | 2.3 | Iris-virginica |
121 | 5.6 | 2.8 | 4.9 | 2.0 | Iris-virginica |
122 | 7.7 | 2.8 | 6.7 | 2.0 | Iris-virginica |
123 | 6.3 | 2.7 | 4.9 | 1.8 | Iris-virginica |
124 | 6.7 | 3.3 | 5.7 | 2.1 | Iris-virginica |
125 | 7.2 | 3.2 | 6.0 | 1.8 | Iris-virginica |
126 | 6.2 | 2.8 | 4.8 | 1.8 | Iris-virginica |
127 | 6.1 | 3.0 | 4.9 | 1.8 | Iris-virginica |
128 | 6.4 | 2.8 | 5.6 | 2.1 | Iris-virginica |
129 | 7.2 | 3.0 | 5.8 | 1.6 | Iris-virginica |
130 | 7.4 | 2.8 | 6.1 | 1.9 | Iris-virginica |
131 | 7.9 | 3.8 | 6.4 | 2.0 | Iris-virginica |
132 | 6.4 | 2.8 | 5.6 | 2.2 | Iris-virginica |
133 | 6.3 | 2.8 | 5.1 | 1.5 | Iris-virginica |
134 | 6.1 | 2.6 | 5.6 | 1.4 | Iris-virginica |
135 | 7.7 | 3.0 | 6.1 | 2.3 | Iris-virginica |
136 | 6.3 | 3.4 | 5.6 | 2.4 | Iris-virginica |
137 | 6.4 | 3.1 | 5.5 | 1.8 | Iris-virginica |
138 | 6.0 | 3.0 | 4.8 | 1.8 | Iris-virginica |
139 | 6.9 | 3.1 | 5.4 | 2.1 | Iris-virginica |
140 | 6.7 | 3.1 | 5.6 | 2.4 | Iris-virginica |
141 | 6.9 | 3.1 | 5.1 | 2.3 | Iris-virginica |
142 | 5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica |
143 | 6.8 | 3.2 | 5.9 | 2.3 | Iris-virginica |
144 | 6.7 | 3.3 | 5.7 | 2.5 | Iris-virginica |
145 | 6.7 | 3.0 | 5.2 | 2.3 | Iris-virginica |
146 | 6.3 | 2.5 | 5.0 | 1.9 | Iris-virginica |
147 | 6.5 | 3.0 | 5.2 | 2.0 | Iris-virginica |
148 | 6.2 | 3.4 | 5.4 | 2.3 | Iris-virginica |
149 | 5.9 | 3.0 | 5.1 | 1.8 | Iris-virginica |
150 rows × 5 columns
In [35]:
x = data[list(range(4))] # 此处为pd,不能用切片
x
Out[35]:
0 | 1 | 2 | 3 | |
---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 |
1 | 4.9 | 3.0 | 1.4 | 0.2 |
2 | 4.7 | 3.2 | 1.3 | 0.2 |
3 | 4.6 | 3.1 | 1.5 | 0.2 |
4 | 5.0 | 3.6 | 1.4 | 0.2 |
5 | 5.4 | 3.9 | 1.7 | 0.4 |
6 | 4.6 | 3.4 | 1.4 | 0.3 |
7 | 5.0 | 3.4 | 1.5 | 0.2 |
8 | 4.4 | 2.9 | 1.4 | 0.2 |
9 | 4.9 | 3.1 | 1.5 | 0.1 |
10 | 5.4 | 3.7 | 1.5 | 0.2 |
11 | 4.8 | 3.4 | 1.6 | 0.2 |
12 | 4.8 | 3.0 | 1.4 | 0.1 |
13 | 4.3 | 3.0 | 1.1 | 0.1 |
14 | 5.8 | 4.0 | 1.2 | 0.2 |
15 | 5.7 | 4.4 | 1.5 | 0.4 |
16 | 5.4 | 3.9 | 1.3 | 0.4 |
17 | 5.1 | 3.5 | 1.4 | 0.3 |
18 | 5.7 | 3.8 | 1.7 | 0.3 |
19 | 5.1 | 3.8 | 1.5 | 0.3 |
20 | 5.4 | 3.4 | 1.7 | 0.2 |
21 | 5.1 | 3.7 | 1.5 | 0.4 |
22 | 4.6 | 3.6 | 1.0 | 0.2 |
23 | 5.1 | 3.3 | 1.7 | 0.5 |
24 | 4.8 | 3.4 | 1.9 | 0.2 |
25 | 5.0 | 3.0 | 1.6 | 0.2 |
26 | 5.0 | 3.4 | 1.6 | 0.4 |
27 | 5.2 | 3.5 | 1.5 | 0.2 |
28 | 5.2 | 3.4 | 1.4 | 0.2 |
29 | 4.7 | 3.2 | 1.6 | 0.2 |
... | ... | ... | ... | ... |
120 | 6.9 | 3.2 | 5.7 | 2.3 |
121 | 5.6 | 2.8 | 4.9 | 2.0 |
122 | 7.7 | 2.8 | 6.7 | 2.0 |
123 | 6.3 | 2.7 | 4.9 | 1.8 |
124 | 6.7 | 3.3 | 5.7 | 2.1 |
125 | 7.2 | 3.2 | 6.0 | 1.8 |
126 | 6.2 | 2.8 | 4.8 | 1.8 |
127 | 6.1 | 3.0 | 4.9 | 1.8 |
128 | 6.4 | 2.8 | 5.6 | 2.1 |
129 | 7.2 | 3.0 | 5.8 | 1.6 |
130 | 7.4 | 2.8 | 6.1 | 1.9 |
131 | 7.9 | 3.8 | 6.4 | 2.0 |
132 | 6.4 | 2.8 | 5.6 | 2.2 |
133 | 6.3 | 2.8 | 5.1 | 1.5 |
134 | 6.1 | 2.6 | 5.6 | 1.4 |
135 | 7.7 | 3.0 | 6.1 | 2.3 |
136 | 6.3 | 3.4 | 5.6 | 2.4 |
137 | 6.4 | 3.1 | 5.5 | 1.8 |
138 | 6.0 | 3.0 | 4.8 | 1.8 |
139 | 6.9 | 3.1 | 5.4 | 2.1 |
140 | 6.7 | 3.1 | 5.6 | 2.4 |
141 | 6.9 | 3.1 | 5.1 | 2.3 |
142 | 5.8 | 2.7 | 5.1 | 1.9 |
143 | 6.8 | 3.2 | 5.9 | 2.3 |
144 | 6.7 | 3.3 | 5.7 | 2.5 |
145 | 6.7 | 3.0 | 5.2 | 2.3 |
146 | 6.3 | 2.5 | 5.0 | 1.9 |
147 | 6.5 | 3.0 | 5.2 | 2.0 |
148 | 6.2 | 3.4 | 5.4 | 2.3 |
149 | 5.9 | 3.0 | 5.1 | 1.8 |
150 rows × 4 columns
In [36]:
x = x[features]
x
Out[36]:
2 | 3 | |
---|---|---|
0 | 1.4 | 0.2 |
1 | 1.4 | 0.2 |
2 | 1.3 | 0.2 |
3 | 1.5 | 0.2 |
4 | 1.4 | 0.2 |
5 | 1.7 | 0.4 |
6 | 1.4 | 0.3 |
7 | 1.5 | 0.2 |
8 | 1.4 | 0.2 |
9 | 1.5 | 0.1 |
10 | 1.5 | 0.2 |
11 | 1.6 | 0.2 |
12 | 1.4 | 0.1 |
13 | 1.1 | 0.1 |
14 | 1.2 | 0.2 |
15 | 1.5 | 0.4 |
16 | 1.3 | 0.4 |
17 | 1.4 | 0.3 |
18 | 1.7 | 0.3 |
19 | 1.5 | 0.3 |
20 | 1.7 | 0.2 |
21 | 1.5 | 0.4 |
22 | 1.0 | 0.2 |
23 | 1.7 | 0.5 |
24 | 1.9 | 0.2 |
25 | 1.6 | 0.2 |
26 | 1.6 | 0.4 |
27 | 1.5 | 0.2 |
28 | 1.4 | 0.2 |
29 | 1.6 | 0.2 |
... | ... | ... |
120 | 5.7 | 2.3 |
121 | 4.9 | 2.0 |
122 | 6.7 | 2.0 |
123 | 4.9 | 1.8 |
124 | 5.7 | 2.1 |
125 | 6.0 | 1.8 |
126 | 4.8 | 1.8 |
127 | 4.9 | 1.8 |
128 | 5.6 | 2.1 |
129 | 5.8 | 1.6 |
130 | 6.1 | 1.9 |
131 | 6.4 | 2.0 |
132 | 5.6 | 2.2 |
133 | 5.1 | 1.5 |
134 | 5.6 | 1.4 |
135 | 6.1 | 2.3 |
136 | 5.6 | 2.4 |
137 | 5.5 | 1.8 |
138 | 4.8 | 1.8 |
139 | 5.4 | 2.1 |
140 | 5.6 | 2.4 |
141 | 5.1 | 2.3 |
142 | 5.1 | 1.9 |
143 | 5.9 | 2.3 |
144 | 5.7 | 2.5 |
145 | 5.2 | 2.3 |
146 | 5.0 | 1.9 |
147 | 5.2 | 2.0 |
148 | 5.4 | 2.3 |
149 | 5.1 | 1.8 |
150 rows × 2 columns
In [37]:
y = pd.Categorical(data[4]).codes # 直接将数据特征转换为0,1,2
y
Out[37]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int8)
In [38]:
print("总样本数目:%d;特征属性数目:%d" % x.shape)
总样本数目:150;特征属性数目:2
In [40]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=14)
print("训练数据集样本数目:%d, 测试数据集样本数目:%d" % (x_train.shape[0], x_test.shape[0]))
训练数据集样本数目:120, 测试数据集样本数目:30
In [41]:
clf = Pipeline([
('sc', StandardScaler()), # 标准化,把它转化成了高斯分布
('poly', PolynomialFeatures(degree=1)),
('clf', GaussianNB())]) # MultinomialNB多项式贝叶斯算法中要求特征属性的取值不能为负数
# 训练模型
clf.fit(x_train, y_train)
Out[41]:
Pipeline(memory=None, steps=[('sc', StandardScaler(copy=True, with_mean=True, with_std=True)), ('poly', PolynomialFeatures(degree=1, include_bias=True, interaction_only=False)), ('clf', GaussianNB(priors=None, var_smoothing=1e-09))])
In [42]:
y_train_hat = clf.predict(x_train)
print('训练集准确度: %.2f%%' % (100 * accuracy_score(y_train, y_train_hat)))
y_test_hat = clf.predict(x_test)
print('测试集准确度:%.2f%%' % (100 * accuracy_score(y_test, y_test_hat)))
训练集准确度: 95.83% 测试集准确度:96.67%