最近因为工作得原因,接触了一点机器学习中得算法,在此记录下来,以供学习。
import numpy as np import copy import matplotlib.pyplot as plt pic = plt.imread('apple.png') plt.imshow(pic) pic.shape data = pic.reshape(-1,4) def kmeans_wave(n, k, data): #n为迭代次数, k为聚类中心, data为输入数据 data_new = copy.deepcopy(data) data_new = np.column_stack((data_new, np.ones(631*982))) center_point = np.random.choice(631*982, k, replace = False) center = data_new[center_point, :] distance = [[] for i in range(k)] for i in range(n): for j in range(k): distance[j] = np.sqrt(np.sum(np.square(data_new-np.array(center[j])), axis=1)) # 更新距离 data_new[:,4] = np.argmin(np.array(distance), axis = 0) # 将最小距离的类别标签作为当前数据的类别 for l in range(k): center[l] = np.mean(data_new[data_new[:,4]==1], axis=0)# 更新聚类中心
return data_new if __name__ == '__main__': data_new = kmeans_wave(100,6,data) print(data_new.shape) pic_new = data_new[:, 4].reshape(631,982) plt.imshow(pic_new) plt.show()
下面是运行结果: