zoukankan      html  css  js  c++  java
  • 01-赵志勇机器学习-Logistics_Regression-train

    Logistics Regression

    二分类问题。

    模型 线性模型
    响应 sigmoid
    损失函数(显示) 最小均方
    优化方法 BGD

    例子:

    #coding utf-8
    import numpy as np
    
    def load_data(file_name):
      feature_data = []
      label_data = []
      
      f = open(file_name) # 打开文件
      for line in f.readlines():
        # @ strip() 去除字符串首尾的空格
    	# @ split("	") 以“	”分割字符串
        lines = line.strip().split("	")
        
        feature_tmp = []
        label_tmp = []
        feature_tmp.append(1) # 偏置项
        
        for i in range(len(lines)-1):
          feature_tmp.append(float(lines[i]))
        label_tmp.append(float(lines[-1]))
        
        feature_data.append(feature_tmp)
        label_data.append(label_tmp)
      
      f.close() # 关闭文件
      
      return np.mat(feature_data), np.mat(label_data)
    
    
    def sig(x):
      return 1.0/(1+np.exp(-x))
    
      
    def compute_error(h, label):
      # @ shape() 获得特征的长度,[0]行数,[1]列数
      n = np.shape(h)[0]
      
      err = 0
      for i in range(n):
        if h[i,0]>0 and (1-h[i,0])>0:
          err -= (label[i,0]*np.log(h[i,0]) 
    	      + (1-label[i,0])*np.log(1-h[i,0]))
        else:
          err -= 0
      
      return err
    
    
    def lr_train_bgd(feature, label, maxCycle, alpha):
      n = np.shape(feature)[1]
      W = np.mat(np.ones((n,1)))
      
      for i in range(maxCycle):
        h = sig(feature*W)
        err = label - h
        if i % 100 == 0:
          print(compute_error(h, label))
        
        W = W + alpha * feature.T * err
      
      return W
    
    def save_model(file_name, W):
      f = open(file_name, "w")
      w_array = []
      n = np.shape(W)[0]
      for i in range(n):
        w_array.append(str(W[i,0]))
      
      f.write("	".join(w_array))
      f.close()
    
    
    if __name__ == "__main__":
      print("load data")
      feature, label = load_data("data.txt")
      print("train")
      w = lr_train_bgd(feature, label, 1000, 0.1)
      print("save")
      save_model("weights2018", w)
    

      

    参考:

    https://blog.csdn.net/google19890102/article/details/77996085

    https://blog.csdn.net/google19890102?viewmode=contents

    https://github.com/zhaozhiyong19890102/Python-Machine-Learning-Algorithm

  • 相关阅读:
    冒泡排序
    数据结构和算法关系
    js获取ifram对象
    java STL
    bufferedReader 乱码问题
    css animation让图标不断旋转
    apply通过实例理解
    jquery.ajaxfileupload.js
    JDBC getMetaData将结果集组装到List
    Android开发之使用BaseAdapter的notifyDataSetChanged()无法更新列表
  • 原文地址:https://www.cnblogs.com/alexYuin/p/8900028.html
Copyright © 2011-2022 走看看