zoukankan      html  css  js  c++  java
  • K-means 手写实现

    #coding=utf-8
    import numpy as np
    import os
    import random
    from tfidf_model import TfIdf
    
    # a = np.array([0, 1, 2, 3, 1, 2, 2])
    # print(a[[2,4,2,2,2,2]])
    # # 设置随机种子
    # random.seed(4)
    
    class Kmeans:
        """Cluster text documents with k-means over their TF-IDF vectors.

        "Distance" between a document and a cluster centre is measured with
        cosine similarity (see ``cal_dist``), so each document is assigned to
        the centre with the *highest* score.
        """

        def __init__(self, doc, k, max_iter):
            """Build the TF-IDF model for ``doc``.

            Args:
                doc: collection of raw documents, handed to ``TfIdf``.
                k: number of clusters. The first ``k`` document vectors seed
                    the centres, so ``k`` must not exceed the number of documents.
                max_iter: maximum number of assign/update rounds in ``train``.
            """
            self.doc = doc
            self.k = k
            self.max_iter = max_iter
            self.tf_idf = TfIdf(doc)
            # NOTE(review): assumes cal_tfidf() fills self.tf_idf.tfidf with one
            # row per document (2-D array-like) -- confirm against tfidf_model.
            self.tf_idf.cal_tfidf()

        def train(self, verbose=True):
            """Run k-means and return the final cluster membership.

            Centres are seeded deterministically with the first ``k`` document
            vectors. Each round assigns every document to its most similar
            centre, then moves each centre to the mean of its members. The loop
            stops as soon as the assignments stop changing, or after
            ``max_iter`` rounds (at least one round is always run).

            Args:
                verbose: when true, print the round number and the cluster
                    membership after every assignment step.

            Returns:
                dict mapping cluster id (0..k-1) to a 1-D integer array of the
                document indices in that cluster. The same dict is stored as
                ``self.cluster_set`` and the final centres as
                ``self.cluster_center``.
            """
            vectors = np.asarray(self.tf_idf.tfidf)
            # 1. Seed the k cluster centres with the first k document vectors.
            cluster_center = {i: vectors[i] for i in range(self.k)}
            labels = None
            cluster_set = {}
            for kmean_iter in range(1, max(1, self.max_iter) + 1):
                # 2. Cosine similarity of every document to every centre.
                doc_sim = np.array([
                    [self.cal_dist(cluster_center[i], vec) for i in range(self.k)]
                    for vec in vectors
                ])
                # Higher similarity means closer, hence argmax (not argmin).
                new_labels = np.argmax(doc_sim, axis=1)
                # 3. Group document indices by their assigned cluster.
                cluster_set = {i: np.flatnonzero(new_labels == i) for i in range(self.k)}
                if verbose:
                    print(kmean_iter)
                    print(cluster_set)
                # 5. Converged: identical assignments would reproduce the same
                #    centres, so further rounds cannot change anything.
                if labels is not None and np.array_equal(new_labels, labels):
                    break
                labels = new_labels
                # 4. Move each centre to the mean of its members. An empty
                #    cluster keeps its previous centre: np.mean of an empty
                #    selection is NaN, and NaN scores would make argmax send
                #    every document to that cluster on the next round.
                cluster_center = {
                    i: vectors[members].mean(axis=0) if members.size else cluster_center[i]
                    for i, members in cluster_set.items()
                }
            self.cluster_set = cluster_set
            self.cluster_center = cluster_center
            return cluster_set

        @staticmethod
        def cal_dist(vec1, vec2):
            """Return the cosine similarity of two vectors, rounded to 4 decimals.

            Despite the name this is a similarity, not a distance: 1.0 means
            the vectors point the same way. If either vector has zero norm the
            similarity is reported as 0.0 instead of dividing by zero (NaN).
            """
            denom = np.sqrt(np.dot(vec1, vec1) * np.dot(vec2, vec2))
            if denom == 0:
                return 0.0
            return round(np.dot(vec1, vec2) / denom, 4)
        
    if __name__ == "__main__":
        # Load every regular file in doc_dir as one UTF-8 document, then cluster
        # the documents into 3 groups with at most 100 k-means rounds.
        doc_dir = 'test_text'
        doc = []
        # Sort the listing: os.listdir order is arbitrary, and Kmeans seeds its
        # centres from the first k documents, so an unsorted listing makes the
        # clustering result depend on the filesystem.
        for file_name in sorted(os.listdir(doc_dir)):
            file_path = os.path.join(doc_dir, file_name)
            if not os.path.isfile(file_path):
                # Skip sub-directories and other non-regular entries, which
                # open() cannot read as text.
                continue
            with open(file_path, encoding="utf-8") as f:
                doc.append(f.read())
        kmeans = Kmeans(doc, 3, 100)
        kmeans.train()
        
        
  • 相关阅读:
    为什么linux有足够的内存还进行swap?
    vmstat命令的使用
    Windows远程服务器不能复制粘贴
    Windows可以ping通百度,但是用浏览器打不开网页
    java形式参数分别是基本类型和引用类型的调用
    Ubuntu16.04安装Python3.6 和pip
    Python2/3共存,pip2/3共存
    multiprocessing模块
    Python-进程与线程
    鼠标不能动,插上了但没反应
  • 原文地址:https://www.cnblogs.com/xiaoruirui/p/15625068.html
Copyright © 2011-2022 走看看