  • TensorFlow: logistic regression for binary classification on dense input data

    First, an implementation that calls the tf.nn module as little as possible, hand-writing the relevant functions ourselves:

       

    import tensorflow as tf
    import numpy as np
    import melt_dataset
    import sys
    from sklearn.metrics import roc_auc_score

    def init_weights(shape):
        return tf.Variable(tf.random_normal(shape, stddev=0.01))

    def model(X, w):
        return 1.0 / (1.0 + tf.exp(-(tf.matmul(X, w))))  # hand-written sigmoid

    batch_size = 500
    learning_rate = 0.001
    num_iters = 1020

    argv = sys.argv
    trainset = argv[1]
    testset = argv[2]

    trX, trY = melt_dataset.load_dense_data(trainset)
    print "finish loading train set ", trainset
    teX, teY = melt_dataset.load_dense_data(testset)
    print "finish loading test set ", testset

    num_features = trX[0].shape[0]
    print 'num_features: ', num_features
    print 'trainSet size: ', len(trX)
    print 'testSet size: ', len(teX)
    print 'batch_size:', batch_size, ' learning_rate:', learning_rate, ' num_iters:', num_iters

    X = tf.placeholder("float", [None, num_features])  # create symbolic variables
    Y = tf.placeholder("float", [None, 1])

    w = init_weights([num_features, 1])  # like in linear regression, we need a shared weight matrix for logistic regression

    py_x = model(X, w)

    cost = -tf.reduce_sum(Y * tf.log(py_x) + (1 - Y) * tf.log(1 - py_x))  # binary cross entropy
    train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  # construct optimizer

    predict_op = py_x

    sess = tf.Session()
    init = tf.initialize_all_variables()
    sess.run(init)

    for i in range(num_iters):
        # evaluate on the test set, then train one epoch of mini-batches
        predicts, cost_ = sess.run([predict_op, cost], feed_dict={X: teX, Y: teY})
        print i, 'auc:', roc_auc_score(teY, predicts), 'cost:', cost_
        # note: zip() drops the final partial batch when len(trX) is not a multiple of batch_size
        for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX), batch_size)):
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

    predicts, cost_ = sess.run([predict_op, cost], feed_dict={X: teX, Y: teY})
    print 'final ', 'auc:', roc_auc_score(teY, predicts), 'cost:', cost_

       

    Note that if batch_size is fairly large and the learning rate is also fairly large, nan may appear; this can be avoided by reducing batch_size or lowering the learning rate.
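    Another way to guard the hand-written loss directly is to clip py_x strictly inside (0, 1), so tf.log never sees an exact 0, which is the usual source of the nan. A minimal sketch (my addition, assuming tf.clip_by_value is available in this TensorFlow version; eps is a hypothetical constant to tune) that would replace the cost line in the script above:

    eps = 1e-7  # hypothetical clipping epsilon
    py_x_safe = tf.clip_by_value(py_x, eps, 1.0 - eps)  # keep probabilities away from exactly 0 and 1
    cost = -tf.reduce_sum(Y * tf.log(py_x_safe) + (1 - Y) * tf.log(1 - py_x_safe))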

       

    A better approach is to use TensorFlow's built-in function:

    tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None)

       

    Computes sigmoid cross entropy given logits.

       

    Measures the probability error in discrete classification tasks in which each class is independent and not mutually exclusive. For instance, one could perform multilabel classification where a picture can contain both an elephant and a dog at the same time.

       

    For brevity, let x = logits, z = targets. The logistic loss is

    (the raw logistic loss; note this is equivalent to the -tf.reduce_sum(Y*tf.log(py_x) + (1 - Y) * tf.log(1 - py_x)) used above)

    x - x * z + log(1 + exp(-x))
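    To verify that equivalence (my addition): with p = sigmoid(x) = 1/(1 + e^{-x}), we have log(p) = -log(1 + e^{-x}) and log(1 - p) = -x - log(1 + e^{-x}), so

    \[
    -\bigl[z \log p + (1 - z)\log(1 - p)\bigr]
      = z \log(1 + e^{-x}) + (1 - z)\bigl(x + \log(1 + e^{-x})\bigr)
      = x - xz + \log(1 + e^{-x}).
    \]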

       

    z = 1: x - x + log(1 + exp(-x)) = log(1 + exp(-x))
    z = 0: x + log(1 + exp(-x)) = log(1 + exp(x))

    To ensure stability and avoid overflow, the implementation uses

    (that is, to avoid the overflow that produces the nan seen earlier)

       

    max(x, 0) - x * z + log(1 + exp(-abs(x)))

       

    If z = 1 and x >= 0: log(1 + exp(-x)), identical to the original form.
    If z = 1 and x < 0: -x + log(1 + exp(x)) = log((1 + exp(x)) / exp(x)) = log(1 + exp(-x)), still the same.
    If z = 0 and x <= 0: log(1 + exp(x)), matching the case above.
    If z = 0 and x > 0: x + log(1 + exp(-x)) = log(1 + exp(x)), again the same.

       

    The point is that exp is only ever applied to -abs(x) <= 0, so it can never overflow no matter how large |x| gets; at worst it underflows harmlessly to 0.
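    A quick numerical check (my addition, plain numpy) makes this concrete: the naive formula blows up on a large-magnitude logit, while the stable one returns the correct value.

    import numpy as np

    def naive_loss(x, z):
        # x - x*z + log(1 + exp(-x)): exp(-x) overflows when x is very negative
        return x - x * z + np.log(1.0 + np.exp(-x))

    def stable_loss(x, z):
        # max(x, 0) - x*z + log(1 + exp(-abs(x))): the exp argument is always <= 0
        return np.maximum(x, 0.0) - x * z + np.log(1.0 + np.exp(-np.abs(x)))

    x = -1000.0  # an extreme logit for a positive example (z = 1)
    print naive_loss(x, 1.0)   # inf: exp(1000) overflows float64
    print stable_loss(x, 1.0)  # 1000.0, the correct loss, since log(1 + exp(1000)) ~= 1000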

    logits and targets must have the same type and shape.

    Note that although this op avoids producing nan, in practice the AUC still fluctuates from iteration to iteration; it still seems better to turn the learning rate down.

       

    Finally, tf.nn.sigmoid can also replace the hand-written

    y = 1 / (1 + exp(-x)).

       

    From <http://www.tensorflow.org/api_docs/python/nn.md#sigmoid>

       

    Final version:

    import tensorflow as tf
    import numpy as np
    import melt_dataset
    import sys
    from sklearn.metrics import roc_auc_score

    def init_weights(shape):
        return tf.Variable(tf.random_normal(shape, stddev=0.01))

    def model(X, w):
        return tf.matmul(X, w)  # raw logits; the sigmoid now lives in the loss and in predict_op

    batch_size = 500
    learning_rate = 0.001
    num_iters = 120

    argv = sys.argv
    trainset = argv[1]
    testset = argv[2]

    trX, trY = melt_dataset.load_dense_data(trainset)
    print "finish loading train set ", trainset
    teX, teY = melt_dataset.load_dense_data(testset)
    print "finish loading test set ", testset

    num_features = trX[0].shape[0]
    print 'num_features: ', num_features
    print 'trainSet size: ', len(trX)
    print 'testSet size: ', len(teX)
    print 'batch_size:', batch_size, ' learning_rate:', learning_rate, ' num_iters:', num_iters

    X = tf.placeholder("float", [None, num_features])  # create symbolic variables
    Y = tf.placeholder("float", [None, 1])

    w = init_weights([num_features, 1])  # like in linear regression, we need a shared weight matrix for logistic regression

    py_x = model(X, w)

    cost = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(py_x, Y))  # numerically stable loss on the logits
    train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  # construct optimizer
    predict_op = tf.nn.sigmoid(py_x)  # probabilities for AUC

    sess = tf.Session()
    init = tf.initialize_all_variables()
    sess.run(init)

    for i in range(num_iters):
        predicts, cost_ = sess.run([predict_op, cost], feed_dict={X: teX, Y: teY})
        print i, 'auc:', roc_auc_score(teY, predicts), 'cost:', cost_
        for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX), batch_size)):
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

    predicts, cost_ = sess.run([predict_op, cost], feed_dict={X: teX, Y: teY})
    print 'final ', 'auc:', roc_auc_score(teY, predicts), 'cost:', cost_

       

       

    Output of a run:

    ./logistic_regression.py corpus/feature.normed.rand.12000.0_2.txt corpus/feature.normed.rand.12000.1_2.txt

    /home/users/chenghuige/.jumbo/lib/python2.7/site-packages/sklearn/externals/joblib/_multiprocessing_helpers.py:29: UserWarning: This platform lacks a functioning sem_open implementation, therefore, the required synchronization primitives needed will not function, see issue 3770.. joblib will operate in serial mode

    warnings.warn('%s. joblib will operate in serial mode' % (e,))

    ... loading data: corpus/feature.normed.rand.12000.0_2.txt

    10000

    finish loading train set corpus/feature.normed.rand.12000.0_2.txt

    ... loading data: corpus/feature.normed.rand.12000.1_2.txt

    finish loading test set corpus/feature.normed.rand.12000.1_2.txt

    num_features: 493

    trainSet size: 10001

    testSet size: 1999

    batch_size: 500 learning_rate: 0.001 num_iters: 120

    I tensorflow/core/common_runtime/local_device.cc:25] Local device intra op parallelism threads: 24

    I tensorflow/core/common_runtime/local_session.cc:45] Local session inter op parallelism threads: 24

    0 auc: 0.380243958856 cost: 1350.72

    1 auc: 0.876460425134 cost: 538.122

    2 auc: 0.894349010333 cost: 493.974

    3 auc: 0.900608480969 cost: 479.184

    4 auc: 0.904222074311 cost: 471.299

    5 auc: 0.906476144619 cost: 466.076

    6 auc: 0.908203871808 cost: 462.155

    7 auc: 0.909598799088 cost: 458.986

    8 auc: 0.910892234197 cost: 456.31

    9 auc: 0.911966162982 cost: 453.987

    10 auc: 0.912926798182 cost: 451.936

    11 auc: 0.913835507154 cost: 450.105

    12 auc: 0.91452234952 cost: 448.458

    13 auc: 0.915244596132 cost: 446.968

    14 auc: 0.915910195951 cost: 445.613

    15 auc: 0.916483744731 cost: 444.375

    16 auc: 0.916939279358 cost: 443.243

    17 auc: 0.917430218232 cost: 442.202

    18 auc: 0.917977803898 cost: 441.244

    19 auc: 0.918386132865 cost: 440.36

    20 auc: 0.918782660417 cost: 439.541

    21 auc: 0.919139063156 cost: 438.782

    22 auc: 0.919471863066 cost: 438.076

    23 auc: 0.91976925873 cost: 437.418

    24 auc: 0.920017088449 cost: 436.803

    25 auc: 0.920342807509 cost: 436.229

    26 auc: 0.920616600343 cost: 435.69

    27 auc: 0.920784180439 cost: 435.185

    28 auc: 0.920965922233 cost: 434.709

    29 auc: 0.921171266858 cost: 434.261

    30 auc: 0.921402574597 cost: 433.839

    31 auc: 0.921548912146 cost: 433.439

    32 auc: 0.921770778752 cost: 433.061

    33 auc: 0.921959601395 cost: 432.703

    34 auc: 0.922181468002 cost: 432.363

    35 auc: 0.922372650928 cost: 432.04

    36 auc: 0.922474143099 cost: 431.734

    37 auc: 0.922618120365 cost: 431.441

    38 auc: 0.92279278131 cost: 431.163

    39 auc: 0.922929677727 cost: 430.897

    40 auc: 0.92300992735 cost: 430.643

    41 auc: 0.923146823767 cost: 430.401

    42 auc: 0.923253036504 cost: 430.169

    43 auc: 0.92331204358 cost: 429.947

    44 auc: 0.923441859148 cost: 429.735

    45 auc: 0.923557513017 cost: 429.531

    46 auc: 0.923649564056 cost: 429.336

    47 auc: 0.923741615094 cost: 429.148

    48 auc: 0.923826585284 cost: 428.968

    49 auc: 0.923897393775 cost: 428.795

    50 auc: 0.923991805097 cost: 428.629

    51 auc: 0.924067334155 cost: 428.469

    52 auc: 0.924131061797 cost: 428.315

    53 auc: 0.924161745477 cost: 428.166

    54 auc: 0.924220752553 cost: 428.024

    55 auc: 0.924251436232 cost: 427.886

    56 auc: 0.92429628161 cost: 427.754

    57 auc: 0.924355288686 cost: 427.627

    58 auc: 0.924440258876 cost: 427.504

    59 auc: 0.924515787933 cost: 427.385

    60 auc: 0.924560633311 cost: 427.27

    61 auc: 0.92465268435 cost: 427.16

    62 auc: 0.924699890011 cost: 427.053

    63 auc: 0.924780139634 cost: 426.95

    64 auc: 0.924834426144 cost: 426.851

    65 auc: 0.924829705578 cost: 426.755

    66 auc: 0.924867470107 cost: 426.663

    67 auc: 0.924891072937 cost: 426.573

    68 auc: 0.924907594919 cost: 426.487

    69 auc: 0.924950080014 cost: 426.404

    70 auc: 0.924961881429 cost: 426.323

    71 auc: 0.925023248788 cost: 426.245

    72 auc: 0.925051572185 cost: 426.17

    73 auc: 0.925079895581 cost: 426.097

    74 auc: 0.925120020393 cost: 426.027

    75 auc: 0.92517194662 cost: 425.959

    76 auc: 0.925240394828 cost: 425.894

    77 auc: 0.925294681338 cost: 425.83

    78 auc: 0.925330085584 cost: 425.769

    79 auc: 0.925391452943 cost: 425.71

    80 auc: 0.925410335207 cost: 425.653

    81 auc: 0.925436298321 cost: 425.598

    82 auc: 0.925485864265 cost: 425.544

    83 auc: 0.925500025963 cost: 425.493

    84 auc: 0.925530709643 cost: 425.443

    85 auc: 0.92555195219 cost: 425.395

    86 auc: 0.925594437285 cost: 425.349

    87 auc: 0.92563692238 cost: 425.304

    88 auc: 0.925665245776 cost: 425.261

    89 auc: 0.925703010305 cost: 425.219

    90 auc: 0.925721892569 cost: 425.179

    91 auc: 0.925771458513 cost: 425.14

    92 auc: 0.925778539362 cost: 425.103

    93 auc: 0.925778539362 cost: 425.066

    94 auc: 0.925771458513 cost: 425.032

    95 auc: 0.925780899645 cost: 424.998

    96 auc: 0.925809223042 cost: 424.966

    97 auc: 0.925806862759 cost: 424.935

    98 auc: 0.925835186156 cost: 424.905

    99 auc: 0.925851708137 cost: 424.876

    100 auc: 0.925868230118 cost: 424.848

    101 auc: 0.925863509552 cost: 424.821

    102 auc: 0.925887112383 cost: 424.795

    103 auc: 0.925905994647 cost: 424.77

    104 auc: 0.925922516628 cost: 424.747

    105 auc: 0.925915435779 cost: 424.724

    106 auc: 0.925934318043 cost: 424.702

    107 auc: 0.925920156345 cost: 424.681

    108 auc: 0.925965001723 cost: 424.661

    109 auc: 0.925955560591 cost: 424.642

    110 auc: 0.926005126535 cost: 424.623

    111 auc: 0.926026369082 cost: 424.605

    112 auc: 0.926035810214 cost: 424.588

    113 auc: 0.926009847101 cost: 424.572

    114 auc: 0.926000405969 cost: 424.557

    115 auc: 0.926021648516 cost: 424.542

    116 auc: 0.92604053078 cost: 424.528

    117 auc: 0.926057052762 cost: 424.514

    118 auc: 0.926075935026 cost: 424.501

    119 auc: 0.926090096724 cost: 424.489

    final auc: 0.926087736441 cost: 424.478

       

    Another thing to note in this program: Y, the labels, must be shaped [batch_size, 1], i.e.
    [
    [0],
    [1],
    [0],
    ]
    A flat [0,1,0…] of shape [batch_size,] will not work.
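    A minimal sketch (my addition) of getting labels into that shape with numpy, assuming they start as a flat array:

    import numpy as np

    labels = np.array([0, 1, 0])    # shape (3,), i.e. [batch_size]
    labels = labels.reshape(-1, 1)  # shape (3, 1), matching the Y placeholder
    print labels.shape              # (3, 1)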

       

    For comparison, here are the results of LinearSVM and gbdt (via mlt):

    mlt -c tt ./corpus/feature.normed.rand.12000.0_2.txt ./corpus/feature.normed.rand.12000.1_2.tx

    Confusion table:
              ||===============================||
              ||           PREDICTED           ||
    TRUTH     || positive  | negative  || RECALL
              ||===============================||
    positive  ||    676    |    532    || 0.5596 (676/1208)
    negative  ||    192    |   8601    || 0.9782 (8601/8793)
              ||===============================||
    PRECISION    0.7788 (676/868)   0.9417 (8601/9133)

    LOG-LOSS/instance:                0.2412

    LOG-LOSS-PROB/instance:                0.1912

    TEST-SET ENTROPY (prior LL/in):        0.3685

    LOG-LOSS REDUCTION (RIG):        48.1082%

       

    OVERALL 0/1 ACCURACY:        0.9276 (9277/10001)

    POS.PRECISION:                0.7788

    POS.RECALL:                0.5596

    NEG.PRECISION:                0.9417

    NEG.RECALL:                0.9782

    F1.SCORE:                 0.6513

    OuputAUC: 0.9309

    AUC: [0.9309]

       

    mlt -c tt ./corpus/feature.normed.rand.12000.0_2.txt ./corpus/feature.normed.rand.12000.1_2.txt -cl gbdt

    Confusion table:
              ||===============================||
              ||           PREDICTED           ||
    TRUTH     || positive  | negative  || RECALL
              ||===============================||
    positive  ||   1194    |     14    || 0.9884 (1194/1208)
    negative  ||      0    |   8793    || 1.0000 (8793/8793)
              ||===============================||
    PRECISION    1.0000 (1194/1194)   0.9984 (8793/8807)

    LOG-LOSS/instance:                0.0214

    LOG-LOSS-PROB/instance:                0.0097

    TEST-SET ENTROPY (prior LL/in):        0.3685

    LOG-LOSS REDUCTION (RIG):        97.3625%

       

    OVERALL 0/1 ACCURACY:        0.9986 (9987/10001)

    POS.PRECISION:                1.0000

    POS.RECALL:                0.9884

    NEG.PRECISION:                0.9984

    NEG.RECALL:                1.0000

    F1.SCORE:                 0.9942

    OuputAUC: 0.9988

    AUC: [0.9988]

       

       

  • Original post: https://www.cnblogs.com/rocketfan/p/4984022.html