zoukankan      html  css  js  c++  java
  • logisticregression

      1 from numpy import *
      2 import random
      3 import time
      4 st = time.time()
      5 
      6 def loaddata(filename):
      7     fr = open(''.join([filename, '.txt'])).readlines()
      8     trainx = [[1] + map(float, line.split()[:-1]) for line in fr] # trainx = [[1,12.2,22.4],[1,22.3,31.2],...]
      9     trainy = [[float(line.split()[-1])] for line in fr] # trainy = [0,1,1,0,...]
     10     return trainx, trainy
     11 
     12 def sigmod(z):
     13     return 1.0 / (1 + exp(-z))
     14 
     15 def optimizaion(trainx, trainy):
     16     trainxmat = mat(trainx)
     17     m = len(trainx)
     18     # beta = [0,0,0]
     19     beta = ones((len(trainx[0]),1)) # array
     20     # maxiter
     21     M = 500
     22     """
     23     # error permid
     24     e = 
     25     """
     26     """
     27     for i in xrange(M):
     28         #if error2sum > e:
     29         # z = betat.T * x = trainx (matricdoc)* beta = [beta.Tx1,beta.Tx2,...,beta.Txn]
     30         sigmodz = sigmod(trainxmat * beta)
     31         # [error_i = yi - sigmod(zi)]
     32         error = trainy - sigmodz
     33         # update beta
     34         beta += alpha * trainxmat.T * error
     35         print beta
     36         """
     37     # random gradascent
     38     for j in xrange(M):
     39         for i in xrange(m):
     40             # per span
     41             alpha = 0.01 + 4 / (1.0 + i +j)
     42             randid = random.randint(0, m - 1)
     43             sigmodz = sigmod(trainxmat[randid] * beta)
     44             error = trainy[randid] - sigmodz
     45             beta += alpha * trainxmat[randid].T * error
     46             #print beta
     47 
     48     return beta
     49 
     50 
     51 def logregress(testx, beta):
     52     if mat(testx) * beta > 0: return [1.0]
     53     else: return [0.0]
     54 
     55 def main():
     56     # step 1: loading data...
     57     print "step 1: loading data..."
     58     trainx, trainy = loaddata('horseColicTraining')
     59     testx, testy = loaddata('horseColicTest')
     60     """
     61     print 'trainx', trainx
     62     print 'trainy', trainy
     63     print 'testx', testx
     64     print 'testy', testy
     65     print 'testy[2]',testy[2]
     66     """
     67 
     68     # step 2: training...
     69     print "step 2: training..."
     70     beta = optimizaion(trainx, trainy)
     71     #print "beta = ",beta
     72 
     73     # step 3: testing...
     74     print "step 3: testing..."
     75     numTests = 10; errorSum = 0.0; l = len(testx)
     76     for j in xrange(numTests):
     77         errorcount = 0.0
     78         #print 'the total number is: ',l
     79         for i in xrange(l):
     80             if logregress(testx[i], beta) != testy[i]: 
     81                 errorcount += 1
     82         #print "the number of error is: ", errorcount
     83         print "the error rate is: ", (errorcount / l)
     84         errorSum += (errorcount / l)
     85     print "after %d iterations the average error rate is: %f" %(numTests, errorSum/numTests)
     86 
     87 
     88 
     89 """
     90 trainx, trainy = loaddata('testSet')
     91 print trainy
     92 optimizaion(trainx, trainy)
     93 """
     94 
     95 main()
     96 
     97 print "cost time: ", (time.time() - st)
     98 
     99 """ lineregres
    100         # ssi = sigmod(zi) - sigmod(zi) ** 2
    101         ss = [sigmodzi - sigmodzi ** 2 for sigmodzi in sigmodz]
    102         # errssi = errori * ssi
    103         errss = map(lambda x, y: x * y, error, ss)
    104         # treri = errssi * trainxi(vector)
    105         trer = [errss[i] * array(trainx[i]) for i in xrange(m)]
    106         """
  • 相关阅读:
    Delphi 通过Access Violation地址错误找到错误的哪行代码
    GitHub 转载:github删除repository
    GitHub 转载:github的高级搜索
    SVN 转载:svn报错:previous operation has not finished; run 'cleanup' if it was interrupted
    GitHub 转载:github新手使用
    Delphi 对应JAVA的MD5加密处理
    Delphi 对应JAVA的BASE64位加密处理
    Delphi 对应JAVA的URL编码处理
    python基础(五)
    DataFrame
  • 原文地址:https://www.cnblogs.com/monne/p/4251804.html
Copyright © 2011-2022 走看看