zoukankan      html  css  js  c++  java
  • 李宏毅深度学习第二次作业 Logistic regression 预测年薪超过50W

     1 import pandas as pd
     2 import numpy as np
     3 '''
     4 整体和PM2.5差不多
     5 参考博客:https://www.cnblogs.com/HL-space/p/10785225.html
     6 https://www.cnblogs.com/tingtin/p/12321465.html
     7 '''
     8 epsilon = 1e-5
     9 def train(x_train,y_train,epoch):
    10     num =x_train.shape[0]#row
    11     feat= x_train.shape[1]#col
    12     bias  = 0
    13     w = np.ones(feat)
    14     lr =1
    15     reg_rate=0.001
    16     b_sum=0
    17     w_sum = np.zeros(feat)
    18 
    19     for i in range(epoch):
    20         b_ =0
    21         w_ = np.zeros(feat)
    22         for j in range(num):
    23             y = w.dot(x_train[j,:])+bias
    24             sig = 1/(1+np.exp(-y))
    25             b_ += (-1)*(y_train[j]-sig)
    26             for k in range(feat):
    27                 w_[k] += (-1)*(y_train[j]-sig)*x_train[j,k]+2*reg_rate*w[k]#加入正则化
    28         b_/=num
    29         w_/=num
    30 
    31         b_sum+=b_**2
    32         w_sum+=w_**2
    33 
    34 
    35         bias-=lr/b_sum**0.5*b_
    36         w-=lr/w_sum**0.5*w_
    37 
    38 
    39         if i%3==0:
    40             loss = 0
    41             acc  =0
    42             result = np.zeros(num)
    43             for j in range(num):
    44                 y = w.dot(x_train[j,:])+bias
    45                 sig =1/(1+np.exp(-y))
    46                 if sig >=0.5:#大于0.5认为年薪>50W
    47                     result[j] =1
    48                 else:
    49                     result[j] = 0
    50                 if result[j] ==y_train[j]:
    51                     acc+=1.0
    52                #log(x) x接近0可能溢出,那么+1e-5
    53                 loss+=(-1)*(y_train[j]*np.log(sig+epsilon) +(1-y_train[j]*np.log(1-sig+epsilon)))#1-sig后面也要加1e-5
    54             print('after {} epochs, the loss on train data is:'.format(i), loss / num)
    55             print('after {} epochs,the acc on train data is:'.format(i), acc / num)
    56 
    57 
    58     return w,bias
    59 
    60 
    61 
    62 
    63 def val(x_val,y_val,w,bias):
    64     num = x_val.shape[0]#500
    65     acc = 0
    66     result = np.zeros(num)
    67     for j in range(num):
    68         y = w.dot(x_val[j, :]) + bias
    69         sig = 1 / (1 + np.exp(-y))
    70         if sig >= 0.5:
    71             result[j] = 1
    72         else:
    73             result[j] = 0
    74         if result[j] == y_val[j]:
    75             acc += 1.0
    76     return  acc/num
    77 
    78 def main():
    79     cs = pd.read_csv('train.csv')
    80     cs  = cs.fillna(0)## 用一个数字(此处用0)填充缺失值
    81 
    82     array = np.array(cs)
    83 
    84     x = array[:,1:-1]#第二列到倒数第二列
    85     x[:,-1]/=np.mean(x[:,-1])#x[]的最后一列的值均除以该列的均值
    86     x[:, -2] /= np.mean(x[:, -2])
    87 
    88     y = array[:,-1]#取array的最后一列
    89     x_train, x_val = x[0:3500,:],x[3500:4000,:]
    90     y_train,y_val = y[0:3500],y[3500:4000]
    91     epoch = 30
    92     w,b  = train(x_train,y_train,epoch)
    93     acc = val(x_val,y_val,w,b)
    94     print('The acc on test data is: ',acc)
    95 
    96 
    97 if __name__ =='__main__':
    98     main()

     

    数据集下载

    链接: https://pan.baidu.com/s/10v3I-nCi9yM8Mc0IJRBmaA 提取码: hyje 

  • 相关阅读:
    算法笔记 --- Scale Sort
    算法笔记 --- String Rotation
    Css3动画缩放
    第一天
    SpringMVC_Controller中方法的返回值
    SpringMVC_url-pattern的写法
    SpringMVC_注释编写SpringMvc程序,RequestMapping常用属性,请求提交方式,请求携带参数
    SpringMVC_数据校验
    SpringMVC_类型转换器
    SpringMVC_异常处理的三种方式
  • 原文地址:https://www.cnblogs.com/tingtin/p/12327456.html
Copyright © 2011-2022 走看看