  • TensorFlow k-means clustering

     Clustering the iris dataset:

    # -*- coding: utf-8 -*-
    # K-means with TensorFlow
    #----------------------------------
    #
    # This script shows how to do k-means with TensorFlow
    
    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow as tf
    from sklearn import datasets
    from scipy.spatial import cKDTree
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import scale
    from tensorflow.python.framework import ops
    ops.reset_default_graph()
    
    sess = tf.Session()
    
    iris = datasets.load_iris()
    
    num_pts = len(iris.data)
    num_feats = len(iris.data[0])
    
    # Set k-means parameters
    # There are 3 types of iris flowers, see if we can predict them
    k=3 
    generations = 25
    
    data_points = tf.Variable(iris.data)
    cluster_labels = tf.Variable(tf.zeros([num_pts], dtype=tf.int64))
    
    # Randomly choose starting points
    rand_starts = np.array([iris.data[np.random.choice(len(iris.data))] for _ in range(k)])
    
    centroids = tf.Variable(rand_starts)
    
    # To compute the distance between every data point and every centroid,
    # tile the centroids into a (num_pts, k, num_feats) tensor
    centroid_matrix = tf.reshape(tf.tile(centroids, [num_pts, 1]), [num_pts, k, num_feats])
    # Likewise repeat each data point k times so the two tensors line up
    point_matrix = tf.reshape(tf.tile(data_points, [1, k]), [num_pts, k, num_feats])
    # Squared Euclidean distance from every point to every centroid, shape (num_pts, k)
    distances = tf.reduce_sum(tf.square(point_matrix - centroid_matrix), axis=2)
    
    # Assign each point to the group of its nearest centroid with tf.argmin()
    centroid_group = tf.argmin(distances, 1)
    
    # Find the group average
    def data_group_avg(group_ids, data):
        # Sum each group
        sum_total = tf.unsorted_segment_sum(data, group_ids, 3)
        # Count each group
        num_total = tf.unsorted_segment_sum(tf.ones_like(data), group_ids, 3)
        # Calculate average
        avg_by_group = sum_total/num_total
        return(avg_by_group)
    
    means = data_group_avg(centroid_group, data_points)
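    # A small worked example of data_group_avg (not part of the script): with
    # group_ids = [0, 2, 0, 1] and data = [[1.,1.],[2.,2.],[3.,3.],[4.,4.]],
    # unsorted_segment_sum gives per-group sums [[4.,4.],[4.,4.],[2.,2.]] and
    # counts [[2.,2.],[1.,1.],[1.,1.]], so the group means are [[2.,2.],[4.,4.],[2.,2.]]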
    
    # One k-means iteration: move centroids to the group means and store the assignments
    update = tf.group(centroids.assign(means), cluster_labels.assign(centroid_group))
    
    init = tf.global_variables_initializer()
    
    sess.run(init)
    
    for i in range(generations):
        print('Calculating gen {}, out of {}.'.format(i, generations))
        _, centroid_group_count = sess.run([update, centroid_group])
        group_count = []
        for ix in range(k):
            group_count.append(np.sum(centroid_group_count==ix))
        print('Group counts: {}'.format(group_count))
        
    
    [centers, assignments] = sess.run([centroids, cluster_labels])
    
    # Find which group assignments correspond to which group labels
    # First, need a most common element function
    def most_common(my_list):
        return(max(set(my_list), key=my_list.count))
    
    label0 = most_common(list(assignments[0:50]))
    label1 = most_common(list(assignments[50:100]))
    label2 = most_common(list(assignments[100:150]))
    
    group0_count = np.sum(assignments[0:50]==label0)
    group1_count = np.sum(assignments[50:100]==label1)
    group2_count = np.sum(assignments[100:150]==label2)
    
    accuracy = (group0_count + group1_count + group2_count)/150.
    
    print('Accuracy: {:.2}'.format(accuracy))
    
    # Also plot the output
    # First use PCA to transform the 4-dimensional data into 2-dimensions
    pca_model = PCA(n_components=2)
    reduced_data = pca_model.fit_transform(iris.data)
    # Transform centers
    reduced_centers = pca_model.transform(centers)
    
    # Step size of mesh for plotting
    h = .02
    
    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max] x [y_min, y_max].
    x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
    y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    
    # Get k-means classifications for the grid points
    xx_pt = list(xx.ravel())
    yy_pt = list(yy.ravel())
    xy_pts = np.array([[x,y] for x,y in zip(xx_pt, yy_pt)])
    mytree = cKDTree(reduced_centers)
    dist, indexes = mytree.query(xy_pts)
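    # For each mesh point, cKDTree.query returns the distance to and index of
    # the nearest reduced centroid; that index is the point's cluster assignment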
    
    # Put the result into a color plot
    indexes = indexes.reshape(xx.shape)
    plt.figure(1)
    plt.clf()
    plt.imshow(indexes, interpolation='nearest',
               extent=(xx.min(), xx.max(), yy.min(), yy.max()),
               cmap=plt.cm.Paired,
               aspect='auto', origin='lower')
    
    # Plot each of the true iris data groups
    symbols = ['o', '^', 'D']
    label_name = ['Setosa', 'Versicolour', 'Virginica']
    for i in range(3):
        temp_group = reduced_data[(i*50):(50)*(i+1)]
        plt.plot(temp_group[:, 0], temp_group[:, 1], symbols[i], markersize=10, label=label_name[i])
    # Plot the centroids as a white X
    plt.scatter(reduced_centers[:, 0], reduced_centers[:, 1],
                marker='x', s=169, linewidths=3,
                color='w', zorder=10)
    plt.title('K-means clustering on Iris Dataset\n'
              'Centroids are marked with white cross')
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.legend(loc='lower right')
    plt.show()
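
As an aside, the tile/reshape pattern used above to build the (num_pts, k) distance matrix can also be written with broadcasting. A minimal NumPy sketch with made-up toy data (not part of the script above; shapes are only for illustration):

    import numpy as np

    points = np.random.rand(150, 4)                           # (num_pts, num_feats)
    cents = points[np.random.choice(150, 3, replace=False)]   # (k, num_feats)

    # (num_pts, 1, f) - (1, k, f) broadcasts to (num_pts, k, f)
    sq_dist = np.sum((points[:, None, :] - cents[None, :, :]) ** 2, axis=2)
    nearest = np.argmin(sq_dist, axis=1)                      # nearest centroid per point

The same broadcasting should also work in TensorFlow via tf.expand_dims(data_points, 1) and tf.expand_dims(centroids, 0), avoiding the explicit tile/reshape.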
    

     

  • Original article: https://www.cnblogs.com/bonelee/p/9011684.html