zoukankan      html  css  js  c++  java
  • 使用numpy实现机器学习模型

    LR (logistic regression):

    import numpy as np
    import matplotlib.pyplot as plt
    
    def gene_dataset(opt='linear', pos_num=100, neg_num=100):
        """Generate a two-class, two-feature toy dataset.

        Args:
            opt: dataset variant; only 'linear' is supported.
            pos_num: number of positive (label 1) samples.
            neg_num: number of negative (label 0) samples.

        Returns:
            X: array of shape (2, pos_num + neg_num), one sample per column.
                Row 0 is the first feature and row 1 the second.
            Y: array of shape (1, pos_num + neg_num). The first pos_num
                columns are 1 and the remaining neg_num columns are 0.

        Raises:
            ValueError: if ``opt`` is not a supported variant.
        """
        if opt != 'linear':
            raise ValueError('unsupported opt: {!r}'.format(opt))

        X = np.zeros((2, pos_num + neg_num))
        Y = np.zeros((1, pos_num + neg_num))

        # Positive class: first feature ~ N(mean=-1, std=3); second feature is
        # 2*x + 10 + 0.1*x**2 plus N(0, 5) noise.
        x1 = np.random.normal(loc=-1, scale=3, size=pos_num)  # mean, std dev, sample count
        X[0, :pos_num] = x1
        X[1, :pos_num] = 2*x1 + 10 + 0.1*x1**2 + np.random.normal(loc=0, scale=5, size=pos_num)
        Y[0, :pos_num] = 1

        # Negative class: first feature ~ N(mean=1, std=3); second feature is
        # 2*x - 5 - 0.1*x**2 plus N(0, 5) noise. It must use its own samples
        # x2; reusing x1 would give both classes identical first features.
        x2 = np.random.normal(loc=1, scale=3, size=neg_num)  # mean, std dev, sample count
        X[0, pos_num:] = x2
        X[1, pos_num:] = 2*x2 - 5 - 0.1*x2**2 + np.random.normal(loc=0, scale=5, size=neg_num)
        Y[0, pos_num:] = 0

        return X, Y
    
    def plotData(X, Y):
        """Scatter-plot a labelled 2-D dataset in a new matplotlib figure.

        X holds one sample per column: row 0 is plotted on the x-axis and row 1
        on the y-axis. Row 0 of Y gives the labels. Samples labelled 1 are drawn
        as red '+' and samples labelled 0 as blue 'o'.
        """
        plt.figure()
        for label, fmt in ((1, 'r+'), (0, 'bo')):
            selected = (Y == label)[0, :]
            plt.plot(X[0, selected], X[1, selected], fmt)
        
    # Build a sample dataset and show it for a quick visual check.
    X, Y = gene_dataset(opt='linear')
    plotData(X, Y)
    def sigmoid(X):
        """Element-wise logistic function 1 / (1 + exp(-X)), computed stably.

        The naive formula evaluates exp(-X), which overflows (with a
        RuntimeWarning) for large negative X. Here exp is only evaluated at
        -|X| <= 0, so it cannot overflow:
            X >= 0:  1 / (1 + exp(-X))
            X <  0:  exp(X) / (1 + exp(X))
        A scalar input gives a NumPy scalar. An array input gives an array of
        the same shape.
        """
        e = np.exp(-np.abs(X))
        out = np.where(X >= 0, 1 / (1 + e), e / (1 + e))
        return out[()]  # unwraps a 0-d result to a scalar; no-op for n-d arrays
    
    def loss_function(Y, P, eps=1e-12):
        """Element-wise binary log-likelihood Y*log(P) + (1-Y)*log(1-P).

        The result is the log-likelihood, not its negation. The caller negates
        and averages it to get a cost.

        Args:
            Y: labels, expected to be 0 or 1.
            P: predicted probabilities of label 1, broadcastable against Y.
            eps: P is clipped to [eps, 1 - eps] before taking logs. Without
                this, a saturated prediction (P exactly 0 or 1) makes one log
                term -inf, and multiplying it by a zero label weight gives nan.
        """
        P = np.clip(P, eps, 1 - eps)
        return Y*np.log(P) + (1-Y)*np.log(1-P)
    
    def cost_function(Y, P):
        """Return the negated sum of loss_function(Y, P), divided by m.

        m is the number of columns of Y, i.e. the number of samples.
        """
        n_samples = Y.shape[1]
        log_lik = np.sum(loss_function(Y, P))
        return -log_lik / n_samples
    
    def LR(X, Y, alpha, w, b, epoches, print_every=10):
        """Fit logistic-regression parameters by batch gradient descent.

        For a single sample x with label y, let p = sigmoid(w.T @ x + b) and
        let L be the cross-entropy loss. Then
            dL/dw = (p - y) * x
            dL/db = p - y
        Each step averages these gradients over all m samples.

        Args:
            X: features, shape (n, m), one sample per column.
            Y: labels, shape (1, m).
            alpha: learning rate.
            w: initial weights, shape (n, 1).
            b: initial bias.
            epoches: number of gradient-descent steps.
            print_every: print the cost every this many epochs. The cost is
                computed with the parameters from before that step's update.
                0 or None disables printing.

        Returns:
            (w, b) after the last update.
        """
        m = Y.shape[1]
        for epoch in range(epoches):
            Z = sigmoid(np.dot(w.transpose(), X) + b)  # w (n,1), X (n,m) -> (1,m)
            dw = np.dot(X, (Z - Y).transpose()) / m      # X (n,m) @ (Z-Y).T (m,1) -> (n,1)
            db = np.sum(Z - Y) / m
            w = w - alpha * dw
            b = b - alpha * db
            if print_every and epoch % print_every == 0:
                print(cost_function(Y, Z))
        return w, b
    
    def pred_func(X, w, b):
        """Predict binary labels for the columns of X.

        Returns a boolean array that is True where
        sigmoid(w.T @ X + b) >= 0.5.
        """
        scores = np.dot(w.T, X) + b
        probs = sigmoid(scores)
        return probs >= 0.5
    def init_wb(featnum):
        """Return zero-initialised parameters.

        The weights have shape (featnum, 1) and the bias is the int 0.
        """
        weights = np.zeros(shape=(featnum, 1))
        bias = 0
        return weights, bias
    
    # Generate data, train logistic regression, then report the predictions
    # and the learned parameters.
    X, Y = gene_dataset()
    plotData(X, Y)

    w, b = init_wb(X.shape[0])
    w, b = LR(X, Y, alpha=0.1, w=w, b=b, epoches=50)

    pred = pred_func(X, w, b)
    print(pred)
    print(w, b)

    K-means:

    import numpy as np
    import matplotlib.pyplot as plt
    from copy import deepcopy
    
    def distance(v1, v2):
        """Return the Euclidean distance between vectors v1 and v2."""
        sq_diff = (v1 - v2) ** 2
        total = np.sum(sq_diff)
        return total ** 0.5
    
    def init(X, k):
        """Pick k distinct rows of X at random as the starting centroids.

        The row indices are shuffled with np.random.shuffle, so the choice
        follows the global NumPy RNG state. Returns a float array of shape
        (k, X.shape[1]). Raises IndexError if k exceeds the number of rows.
        """
        n_rows, n_cols = X.shape[0], X.shape[1]
        order = list(range(n_rows))
        np.random.shuffle(order)
        centers = np.zeros((k, n_cols))
        for row in range(k):
            centers[row, :] = X[order[row], :]
        return centers
        
    def kmeans(X, k, max_iter=300, init_centers=None):
        """Cluster the rows of X into k groups with Lloyd's k-means algorithm.

        Args:
            X: data of shape (m, n), one sample per row.
            k: number of clusters.
            max_iter: upper bound on assignment/update rounds, so the loop
                always terminates.
            init_centers: optional (k, n) starting centroids. When None, k
                distinct rows of X are chosen at random via ``init``. The
                argument is copied and never modified.

        Returns:
            Float array of shape (k, n) holding the final centroids.

        Each round assigns every sample to its nearest centroid (ties go to
        the lowest centroid index), then moves each centroid to the mean of
        its members. A centroid with no members keeps its previous position.
        Iteration stops once the centroids no longer move (np.allclose), or
        after max_iter rounds.
        """
        X = np.asarray(X, dtype=float)
        if init_centers is None:
            center = init(X, k)
        else:
            center = np.array(init_centers, dtype=float)
        for _ in range(max_iter):
            precenter = center.copy()
            # Squared Euclidean distance from every sample to every centroid,
            # shape (m, k). Squaring does not change the argmin, so no sqrt.
            sq_dist = ((X[:, None, :] - center[None, :, :]) ** 2).sum(axis=2)
            labels = np.argmin(sq_dist, axis=1)  # argmin returns the first minimum on ties
            for j in range(k):
                members = X[labels == j]
                if len(members):  # leave an empty cluster's centroid where it is
                    center[j] = members.mean(axis=0)  # mean over samples (rows)
            # Converged: this round did not move any centroid.
            if np.allclose(center, precenter):
                break
        return center
    
    # Two well-separated groups of four 2-D points each, stored as floats.
    X = np.array([[1, 3], [1, 4], [2, 3], [2, 4],
                  [3, 1], [3, 2], [4, 1], [4, 2]], dtype=float)
    print(X)

    print(kmeans(X, 2))
  • 相关阅读:
    构建之法:第二次心得
    构建之法:第一次心得
    tomcat配置限制ip和建立图片服务器
    tomcat8.5优化配置
    java 操作 csv文件
    jsoup教学系列
    (转)js实现倒计时效果(年月日时分秒)
    本地启动tomcat的时候报java.util.concurrent.ExecutionException: java.lang.OutOfMemoryError: PermGen space
    使用mybatis执行oracle存储过程
    java 获取web登录者的ip地址
  • 原文地址:https://www.cnblogs.com/liyinggang/p/13489933.html
Copyright © 2011-2022 走看看