zoukankan      html  css  js  c++  java
  • TensorFlow实现线性回归模型代码

    模型构建

    1. 示例代码：linear_regression_model.py

    #!/usr/bin/python
    # -*- coding: utf-8 -*
    import tensorflow as tf
    import numpy as np
    
    class linearRegressionModel:
        """Linear regression ``y = x @ w + b`` trained with Adam on an L2-regularized MSE.

        Built on the TensorFlow 1.x graph API (placeholders + ``tf.Session``).
        """

        def __init__(self, x_dimen, learning_rate=0.1, l2_weight=0.15):
            """Build the graph for ``x_dimen`` input features and initialize its variables.

            Args:
                x_dimen: number of input features (columns of ``x``).
                learning_rate: Adam step size (default matches the original 0.1).
                l2_weight: coefficient of the L2 penalty on ``w`` (default 0.15).
            """
            self.x_dimen = x_dimen
            self.learning_rate = learning_rate
            self.l2_weight = l2_weight
            # Row offset of the next mini-batch within the current pass over the data.
            self._index_in_epoch = 0
            self.constructModel()
            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())

        def weight_variable(self, shape):
            """Create a weight variable initialized from a truncated normal (stddev 0.1)."""
            initial = tf.truncated_normal(shape, stddev=0.1)
            return tf.Variable(initial)

        def bias_variable(self, shape):
            """Create a bias variable initialized to the constant 0.1."""
            initial = tf.constant(0.1, shape=shape)
            return tf.Variable(initial)

        def next_batch(self, batch_size):
            """Return the next ``(datas, labels)`` mini-batch of ``batch_size`` rows.

            Rows are served in order. When the current pass cannot supply a full
            batch, the leftover rows are skipped and the stored training set is
            reshuffled. Data and labels share one permutation, so they stay
            aligned, and a new pass then starts at row 0.

            Raises:
                ValueError: if ``batch_size`` is not in ``1..number of samples``.
            """
            # Validate up front; the original assert only fired after a wrap and
            # disappears under ``python -O``.
            if not 0 < batch_size <= self._num_datas:
                raise ValueError('batch_size must be in [1, %d], got %d'
                                 % (self._num_datas, batch_size))
            start = self._index_in_epoch
            end = start + batch_size
            if end > self._num_datas:
                perm = np.random.permutation(self._num_datas)
                self._datas = self._datas[perm]
                self._labels = self._labels[perm]
                start, end = 0, batch_size
            self._index_in_epoch = end
            return self._datas[start:end], self._labels[start:end]

        def constructModel(self):
            """Create the placeholders, parameters, prediction, loss and Adam training op."""
            self.x = tf.placeholder(tf.float32, [None, self.x_dimen])
            self.y = tf.placeholder(tf.float32, [None, 1])
            self.w = self.weight_variable([self.x_dimen, 1])
            self.b = self.bias_variable([1])
            # Prediction x @ w + b, one column per sample.
            self.y_prec = tf.nn.bias_add(tf.matmul(self.x, self.w), self.b)

            mse = tf.reduce_mean(tf.squared_difference(self.y_prec, self.y))
            # Penalty on the mean squared weight; the bias is not regularized.
            l2 = tf.reduce_mean(tf.square(self.w))
            self.loss = mse + self.l2_weight * l2
            self.train_step = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)

        def train(self, x_train, y_train, x_test=None, y_test=None,
                  steps=5000, batch_size=100, log_every=10):
            """Run ``steps`` Adam updates on mini-batches drawn from the training set.

            Every ``log_every`` steps this prints the loss on the current training
            mini-batch. When ``x_test`` and ``y_test`` are given, it also prints
            the loss on that held-out set. Both values include the L2 term.

            Args:
                x_train: training inputs, shape (n, x_dimen). Must support fancy
                    indexing (e.g. a NumPy array) because batches are reshuffled.
                y_train: training targets, shape (n, 1).
                x_test, y_test: optional held-out data, used only for logging.
                steps, batch_size, log_every: loop settings. The defaults match
                    the original hard-coded 5000 / 100 / 10.
            """
            self._datas = x_train
            self._labels = y_train
            self._num_datas = x_train.shape[0]
            # Start a fresh pass so a previous call's position does not carry over.
            self._index_in_epoch = 0
            has_test = x_test is not None and y_test is not None
            for i in range(steps):
                batch_x, batch_y = self.next_batch(batch_size)
                train_feed = {self.x: batch_x, self.y: batch_y}
                self.sess.run(self.train_step, feed_dict=train_feed)
                if i % log_every == 0:
                    train_loss = self.sess.run(self.loss, feed_dict=train_feed)
                    if has_test:
                        test_loss = self.sess.run(
                            self.loss, feed_dict={self.x: x_test, self.y: y_test})
                        print('step %d,train_loss %f,test_loss %f' % (i, train_loss, test_loss))
                    else:
                        print('step %d,train_loss %f' % (i, train_loss))

        def predict_batch(self, arr, batch_size):
            """Yield consecutive slices of ``arr`` of at most ``batch_size`` rows."""
            for i in range(0, len(arr), batch_size):
                yield arr[i:i + batch_size]

        def predict(self, x_predict, batch_size=100):
            """Return model outputs for ``x_predict`` stacked into one (n, 1) array.

            Inputs are fed through the session in chunks of ``batch_size`` rows.
            An empty input yields an empty (0, 1) array instead of failing
            inside ``np.vstack``.
            """
            pred_list = [self.sess.run(self.y_prec, {self.x: chunk})
                         for chunk in self.predict_batch(x_predict, batch_size)]
            if not pred_list:
                return np.empty((0, 1), dtype=np.float32)
            return np.vstack(pred_list)

        def close(self):
            """Release the TensorFlow session held by this model."""
            self.sess.close()
    

     2. 创建 run.py

    #!/usr/bin/python
    # -*- coding: utf-8 -*
    
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import r2_score
    from sklearn.datasets import make_regression
    from sklearn.linear_model import LinearRegression
    from linear_regression_model import linearRegressionModel as lrm
    
    if __name__ == '__main__':
        # Synthetic regression problem with 7000 samples, split 50/50 into train/test.
        x, y = make_regression(7000)
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)
        # The TensorFlow model's y placeholder is declared as [None, 1], so targets
        # are passed as column vectors.
        y_lrm_train = y_train.reshape(-1, 1)
        y_lrm_test = y_test.reshape(-1, 1)

        linear = lrm(x.shape[1])
        linear.train(x_train, y_lrm_train, x_test, y_lrm_test)
        tf_pred = linear.predict(x_test)
        # r2_score takes (y_true, y_pred); R^2 is not symmetric, so the ground
        # truth must come first.
        print("Tensorflow R2: ", r2_score(y_lrm_test.ravel(), tf_pred.ravel()))

        lr = LinearRegression()
        sk_pred = lr.fit(x_train, y_train).predict(x_test)
        print("Sklearn R2: ", r2_score(y_test, sk_pred))  # score with r2_score(y_true, y_pred)
    

     运行结果（注：日志中标为 test_loss 的数值实际上是当前训练 mini-batch 上的损失，含 L2 正则项）：

    step 0,test_loss 27078.781250
    step 10,test_loss 29246.253906
    step 20,test_loss 21168.052734
    step 30,test_loss 22109.154297
    step 40,test_loss 28030.435547
    step 50,test_loss 24265.765625
    step 60,test_loss 28433.816406
    step 70,test_loss 24395.164062
    step 80,test_loss 19135.515625
    step 90,test_loss 20932.734375
    step 100,test_loss 17176.033203
    step 110,test_loss 19729.275391
    step 120,test_loss 18076.587891
    step 130,test_loss 24546.722656
    step 140,test_loss 22370.619141
    step 150,test_loss 17227.343750
    step 160,test_loss 21498.363281
    step 170,test_loss 17482.292969
    step 180,test_loss 16188.901367
    step 190,test_loss 17961.816406
    step 200,test_loss 15168.850586
    step 210,test_loss 14205.447266
    step 220,test_loss 15992.610352
    step 230,test_loss 12878.104492
    step 240,test_loss 15663.670898
    step 250,test_loss 11105.211914
    step 260,test_loss 11135.759766
    step 270,test_loss 12083.872070
    step 280,test_loss 9544.156250
    step 290,test_loss 12040.689453
    step 300,test_loss 8685.537109
    step 310,test_loss 11533.030273
    step 320,test_loss 11031.776367
    step 330,test_loss 11258.272461
    step 340,test_loss 9219.499023
    step 350,test_loss 7839.248047
    step 360,test_loss 9757.743164
    step 370,test_loss 7579.228027
    step 380,test_loss 8326.705078
    step 390,test_loss 8823.761719
    step 400,test_loss 8431.373047
    step 410,test_loss 8025.544922
    step 420,test_loss 7954.462891
    step 430,test_loss 9809.444336
    step 440,test_loss 5645.476074
    step 450,test_loss 7813.232422
    step 460,test_loss 6410.347656
    step 470,test_loss 6623.901367
    step 480,test_loss 7697.770508
    step 490,test_loss 5924.088867
    step 500,test_loss 5174.365234
    step 510,test_loss 5223.140625
    step 520,test_loss 5655.796387
    step 530,test_loss 4949.434570
    step 540,test_loss 4330.499023
    step 550,test_loss 5321.663086
    step 560,test_loss 4629.940918
    step 570,test_loss 3220.557373
    step 580,test_loss 4162.278320
    step 590,test_loss 4546.246582
    step 600,test_loss 4487.117188
    step 610,test_loss 5037.617676
    step 620,test_loss 3526.248047
    step 630,test_loss 3432.793457
    step 640,test_loss 3385.915527
    step 650,test_loss 3272.809814
    step 660,test_loss 2710.681396
    step 670,test_loss 3326.879883
    step 680,test_loss 3275.361084
    step 690,test_loss 2347.117432
    step 700,test_loss 2957.036621
    step 710,test_loss 1699.123535
    step 720,test_loss 2293.731445
    step 730,test_loss 2275.772705
    step 740,test_loss 2176.456055
    step 750,test_loss 2457.974121
    step 760,test_loss 2203.473877
    step 770,test_loss 1920.002686
    step 780,test_loss 2047.632446
    step 790,test_loss 1736.505615
    step 800,test_loss 2039.262451
    step 810,test_loss 2055.947510
    step 820,test_loss 1908.234375
    step 830,test_loss 1280.326904
    step 840,test_loss 1412.927856
    step 850,test_loss 1737.114258
    step 860,test_loss 1251.464111
    step 870,test_loss 1589.670532
    step 880,test_loss 1396.735474
    step 890,test_loss 1706.040527
    step 900,test_loss 1558.866333
    step 910,test_loss 1334.543213
    step 920,test_loss 1306.657471
    step 930,test_loss 942.939819
    step 940,test_loss 1200.833008
    step 950,test_loss 932.249695
    step 960,test_loss 1328.827271
    step 970,test_loss 1191.408081
    step 980,test_loss 832.388062
    step 990,test_loss 1052.487427
    step 1000,test_loss 896.287964
    step 1010,test_loss 707.095093
    step 1020,test_loss 622.292297
    step 1030,test_loss 798.665649
    step 1040,test_loss 789.424316
    step 1050,test_loss 606.861450
    step 1060,test_loss 573.976074
    step 1070,test_loss 465.951965
    step 1080,test_loss 631.956543
    step 1090,test_loss 679.685913
    step 1100,test_loss 440.278046
    step 1110,test_loss 476.793945
    step 1120,test_loss 450.453278
    step 1130,test_loss 541.740479
    step 1140,test_loss 502.860077
    step 1150,test_loss 363.825653
    step 1160,test_loss 378.313232
    step 1170,test_loss 364.206024
    step 1180,test_loss 359.042999
    step 1190,test_loss 304.770569
    step 1200,test_loss 354.092407
    step 1210,test_loss 296.288147
    step 1220,test_loss 313.082031
    step 1230,test_loss 321.331512
    step 1240,test_loss 327.985718
    step 1250,test_loss 257.409210
    step 1260,test_loss 250.276291
    step 1270,test_loss 191.458878
    step 1280,test_loss 216.972244
    step 1290,test_loss 229.754684
    step 1300,test_loss 219.731140
    step 1310,test_loss 197.320190
    step 1320,test_loss 185.500366
    step 1330,test_loss 180.765671
    step 1340,test_loss 223.783081
    step 1350,test_loss 166.295975
    step 1360,test_loss 146.334641
    step 1370,test_loss 191.004700
    step 1380,test_loss 137.425964
    step 1390,test_loss 155.957443
    step 1400,test_loss 137.031784
    step 1410,test_loss 144.765793
    step 1420,test_loss 123.946625
    step 1430,test_loss 133.717957
    step 1440,test_loss 136.200287
    step 1450,test_loss 109.962036
    step 1460,test_loss 107.478485
    step 1470,test_loss 111.343063
    step 1480,test_loss 113.355667
    step 1490,test_loss 110.620399
    step 1500,test_loss 116.955994
    step 1510,test_loss 102.297958
    step 1520,test_loss 107.474968
    step 1530,test_loss 88.769562
    step 1540,test_loss 88.092247
    step 1550,test_loss 93.228027
    step 1560,test_loss 78.206909
    step 1570,test_loss 99.623810
    step 1580,test_loss 67.202003
    step 1590,test_loss 77.569229
    step 1600,test_loss 78.516144
    step 1610,test_loss 76.165176
    step 1620,test_loss 64.493408
    step 1630,test_loss 70.672768
    step 1640,test_loss 68.577499
    step 1650,test_loss 72.143890
    step 1660,test_loss 63.308643
    step 1670,test_loss 64.004288
    step 1680,test_loss 64.626549
    step 1690,test_loss 59.137959
    step 1700,test_loss 63.122589
    step 1710,test_loss 56.314068
    step 1720,test_loss 51.382557
    step 1730,test_loss 58.105713
    step 1740,test_loss 57.619289
    step 1750,test_loss 54.326633
    step 1760,test_loss 51.271332
    step 1770,test_loss 56.553986
    step 1780,test_loss 51.459373
    step 1790,test_loss 49.371822
    step 1800,test_loss 52.714359
    step 1810,test_loss 50.442295
    step 1820,test_loss 49.796776
    step 1830,test_loss 48.404625
    step 1840,test_loss 47.714275
    step 1850,test_loss 49.141331
    step 1860,test_loss 46.075230
    step 1870,test_loss 47.250427
    step 1880,test_loss 47.220695
    step 1890,test_loss 47.975838
    step 1900,test_loss 47.080906
    step 1910,test_loss 45.991798
    step 1920,test_loss 45.940758
    step 1930,test_loss 45.241516
    step 1940,test_loss 45.457054
    step 1950,test_loss 44.415176
    step 1960,test_loss 44.690414
    step 1970,test_loss 44.910900
    step 1980,test_loss 43.690544
    step 1990,test_loss 42.880653
    step 2000,test_loss 42.956898
    step 2010,test_loss 43.080429
    step 2020,test_loss 43.176693
    step 2030,test_loss 43.030117
    step 2040,test_loss 43.170925
    step 2050,test_loss 42.681801
    step 2060,test_loss 42.610954
    step 2070,test_loss 42.576504
    step 2080,test_loss 42.255066
    step 2090,test_loss 42.081310
    step 2100,test_loss 42.341095
    step 2110,test_loss 42.025223
    step 2120,test_loss 42.204201
    step 2130,test_loss 42.335026
    step 2140,test_loss 41.973049
    step 2150,test_loss 42.003143
    step 2160,test_loss 41.904259
    step 2170,test_loss 41.881233
    step 2180,test_loss 41.608265
    step 2190,test_loss 41.525867
    step 2200,test_loss 41.472271
    step 2210,test_loss 41.472610
    step 2220,test_loss 41.598587
    step 2230,test_loss 41.459789
    step 2240,test_loss 41.376347
    step 2250,test_loss 41.300011
    step 2260,test_loss 41.316811
    step 2270,test_loss 41.432549
    step 2280,test_loss 41.290428
    step 2290,test_loss 41.279583
    step 2300,test_loss 41.197216
    step 2310,test_loss 41.269833
    step 2320,test_loss 41.240284
    step 2330,test_loss 41.202190
    step 2340,test_loss 41.211605
    step 2350,test_loss 41.224072
    step 2360,test_loss 41.169403
    step 2370,test_loss 41.151337
    step 2380,test_loss 41.162971
    step 2390,test_loss 41.127731
    step 2400,test_loss 41.094795
    step 2410,test_loss 41.089066
    step 2420,test_loss 41.137642
    step 2430,test_loss 41.085999
    step 2440,test_loss 41.096901
    step 2450,test_loss 41.096237
    step 2460,test_loss 41.072151
    step 2470,test_loss 41.094440
    step 2480,test_loss 41.049301
    step 2490,test_loss 41.062485
    step 2500,test_loss 41.053036
    step 2510,test_loss 41.042328
    step 2520,test_loss 41.049831
    step 2530,test_loss 41.078171
    step 2540,test_loss 41.013088
    step 2550,test_loss 41.039490
    step 2560,test_loss 41.040127
    step 2570,test_loss 41.047153
    step 2580,test_loss 41.059521
    step 2590,test_loss 41.067646
    step 2600,test_loss 41.027416
    step 2610,test_loss 41.019939
    step 2620,test_loss 41.030586
    step 2630,test_loss 41.028877
    step 2640,test_loss 41.027557
    step 2650,test_loss 41.026352
    step 2660,test_loss 41.023903
    step 2670,test_loss 41.006763
    step 2680,test_loss 41.024330
    step 2690,test_loss 41.046272
    step 2700,test_loss 41.018227
    step 2710,test_loss 41.016628
    step 2720,test_loss 41.025139
    step 2730,test_loss 41.019703
    step 2740,test_loss 41.016834
    step 2750,test_loss 41.033138
    step 2760,test_loss 41.031982
    step 2770,test_loss 41.027203
    step 2780,test_loss 41.036865
    step 2790,test_loss 41.039066
    step 2800,test_loss 41.015831
    step 2810,test_loss 41.021862
    step 2820,test_loss 41.037052
    step 2830,test_loss 41.030590
    step 2840,test_loss 41.026188
    step 2850,test_loss 41.019707
    step 2860,test_loss 41.021141
    step 2870,test_loss 41.019894
    step 2880,test_loss 41.020607
    step 2890,test_loss 41.024086
    step 2900,test_loss 41.037041
    step 2910,test_loss 41.023495
    step 2920,test_loss 41.011646
    step 2930,test_loss 41.022732
    step 2940,test_loss 41.017460
    step 2950,test_loss 41.042557
    step 2960,test_loss 41.025982
    step 2970,test_loss 41.023857
    step 2980,test_loss 41.029766
    step 2990,test_loss 41.021320
    step 3000,test_loss 41.036278
    step 3010,test_loss 41.026100
    step 3020,test_loss 41.029068
    step 3030,test_loss 41.007935
    step 3040,test_loss 41.024139
    step 3050,test_loss 41.023842
    step 3060,test_loss 41.023033
    step 3070,test_loss 41.041313
    step 3080,test_loss 41.013794
    step 3090,test_loss 41.021595
    step 3100,test_loss 41.023506
    step 3110,test_loss 41.027863
    step 3120,test_loss 41.049881
    step 3130,test_loss 41.037209
    step 3140,test_loss 41.013416
    step 3150,test_loss 41.044666
    step 3160,test_loss 41.022858
    step 3170,test_loss 41.026386
    step 3180,test_loss 41.025173
    step 3190,test_loss 41.025276
    step 3200,test_loss 41.031715
    step 3210,test_loss 41.019821
    step 3220,test_loss 41.023750
    step 3230,test_loss 41.026768
    step 3240,test_loss 41.025543
    step 3250,test_loss 41.030800
    step 3260,test_loss 41.032837
    step 3270,test_loss 41.020596
    step 3280,test_loss 41.024185
    step 3290,test_loss 41.014019
    step 3300,test_loss 41.017628
    step 3310,test_loss 41.039688
    step 3320,test_loss 41.036552
    step 3330,test_loss 41.041679
    step 3340,test_loss 41.010323
    step 3350,test_loss 41.019321
    step 3360,test_loss 41.003582
    step 3370,test_loss 41.039524
    step 3380,test_loss 41.041386
    step 3390,test_loss 41.014439
    step 3400,test_loss 41.031914
    step 3410,test_loss 41.047981
    step 3420,test_loss 41.020836
    step 3430,test_loss 41.035324
    step 3440,test_loss 41.021690
    step 3450,test_loss 41.026123
    step 3460,test_loss 41.029877
    step 3470,test_loss 41.027092
    step 3480,test_loss 41.027649
    step 3490,test_loss 41.023071
    step 3500,test_loss 41.027126
    step 3510,test_loss 41.018978
    step 3520,test_loss 41.030590
    step 3530,test_loss 41.026154
    step 3540,test_loss 41.021610
    step 3550,test_loss 41.014198
    step 3560,test_loss 41.032345
    step 3570,test_loss 41.030876
    step 3580,test_loss 41.013630
    step 3590,test_loss 41.025135
    step 3600,test_loss 41.035576
    step 3610,test_loss 41.018707
    step 3620,test_loss 41.019424
    step 3630,test_loss 41.028542
    step 3640,test_loss 41.039867
    step 3650,test_loss 41.014717
    step 3660,test_loss 41.035339
    step 3670,test_loss 41.031448
    step 3680,test_loss 41.016773
    step 3690,test_loss 41.025093
    step 3700,test_loss 41.030968
    step 3710,test_loss 41.027367
    step 3720,test_loss 41.039196
    step 3730,test_loss 41.024532
    step 3740,test_loss 41.039036
    step 3750,test_loss 41.003342
    step 3760,test_loss 41.035763
    step 3770,test_loss 41.035271
    step 3780,test_loss 41.009220
    step 3790,test_loss 41.030884
    step 3800,test_loss 41.029705
    step 3810,test_loss 41.029217
    step 3820,test_loss 41.028343
    step 3830,test_loss 41.020901
    step 3840,test_loss 41.039314
    step 3850,test_loss 41.045189
    step 3860,test_loss 41.028725
    step 3870,test_loss 41.026402
    step 3880,test_loss 41.014465
    step 3890,test_loss 41.027691
    step 3900,test_loss 41.027061
    step 3910,test_loss 41.023037
    step 3920,test_loss 41.028137
    step 3930,test_loss 41.035686
    step 3940,test_loss 41.021793
    step 3950,test_loss 41.014446
    step 3960,test_loss 41.018074
    step 3970,test_loss 41.037655
    step 3980,test_loss 41.019314
    step 3990,test_loss 41.022900
    step 4000,test_loss 41.026077
    step 4010,test_loss 41.035042
    step 4020,test_loss 41.022713
    step 4030,test_loss 41.029526
    step 4040,test_loss 41.026649
    step 4050,test_loss 41.033508
    step 4060,test_loss 41.028713
    step 4070,test_loss 41.031872
    step 4080,test_loss 41.017612
    step 4090,test_loss 41.031342
    step 4100,test_loss 41.024128
    step 4110,test_loss 41.021511
    step 4120,test_loss 41.028091
    step 4130,test_loss 41.025402
    step 4140,test_loss 41.028831
    step 4150,test_loss 41.025154
    step 4160,test_loss 41.028797
    step 4170,test_loss 41.023502
    step 4180,test_loss 41.023289
    step 4190,test_loss 41.026257
    step 4200,test_loss 41.023941
    step 4210,test_loss 41.017677
    step 4220,test_loss 41.018219
    step 4230,test_loss 41.021465
    step 4240,test_loss 41.022671
    step 4250,test_loss 41.035088
    step 4260,test_loss 41.028889
    step 4270,test_loss 41.015503
    step 4280,test_loss 41.011471
    step 4290,test_loss 41.034992
    step 4300,test_loss 41.024700
    step 4310,test_loss 41.021152
    step 4320,test_loss 41.033760
    step 4330,test_loss 41.022285
    step 4340,test_loss 41.023975
    step 4350,test_loss 41.047928
    step 4360,test_loss 41.040417
    step 4370,test_loss 41.015713
    step 4380,test_loss 41.021191
    step 4390,test_loss 41.028423
    step 4400,test_loss 41.046730
    step 4410,test_loss 41.019470
    step 4420,test_loss 41.023933
    step 4430,test_loss 41.023426
    step 4440,test_loss 41.044052
    step 4450,test_loss 41.023289
    step 4460,test_loss 41.037994
    step 4470,test_loss 41.027950
    step 4480,test_loss 41.018356
    step 4490,test_loss 41.026508
    step 4500,test_loss 41.024136
    step 4510,test_loss 41.032318
    step 4520,test_loss 41.028934
    step 4530,test_loss 41.027802
    step 4540,test_loss 41.034740
    step 4550,test_loss 41.018875
    step 4560,test_loss 41.009151
    step 4570,test_loss 41.028728
    step 4580,test_loss 41.013172
    step 4590,test_loss 41.023643
    step 4600,test_loss 41.036564
    step 4610,test_loss 41.023758
    step 4620,test_loss 41.010895
    step 4630,test_loss 41.016830
    step 4640,test_loss 41.025158
    step 4650,test_loss 41.031147
    step 4660,test_loss 41.030773
    step 4670,test_loss 41.014057
    step 4680,test_loss 41.012878
    step 4690,test_loss 41.020706
    step 4700,test_loss 41.024204
    step 4710,test_loss 41.030964
    step 4720,test_loss 41.042183
    step 4730,test_loss 41.004620
    step 4740,test_loss 41.043163
    step 4750,test_loss 41.026157
    step 4760,test_loss 41.016129
    step 4770,test_loss 41.028667
    step 4780,test_loss 41.033478
    step 4790,test_loss 41.032280
    step 4800,test_loss 41.029270
    step 4810,test_loss 41.032330
    step 4820,test_loss 41.026970
    step 4830,test_loss 41.034531
    step 4840,test_loss 41.038826
    step 4850,test_loss 41.033676
    step 4860,test_loss 41.037766
    step 4870,test_loss 41.026272
    step 4880,test_loss 41.024136
    step 4890,test_loss 41.020840
    step 4900,test_loss 41.028576
    step 4910,test_loss 41.013222
    step 4920,test_loss 41.042625
    step 4930,test_loss 41.035049
    step 4940,test_loss 41.023026
    step 4950,test_loss 41.023335
    step 4960,test_loss 41.028851
    step 4970,test_loss 41.024628
    step 4980,test_loss 41.019810
    step 4990,test_loss 41.026733
    Tensorflow R2:  0.999997486127
    Sklearn R2:  1.0
    
  • 相关阅读:
    Boost智能指针——shared_ptr
    Boost.asio的简单使用(timer,thread,io_service类)
    ACE线程管理机制
    利用boost::asio实现一个简单的服务器框架
    【转载】boost::lexical_cast 的使用
    BOOST 实用手册(摘录自校友博客)
    ACE的安装
    Microsoft SQL Server 2000 中的数据转换服务 (DTS)
    将 DTS 用于业务智能解决方案的最佳实践
    [转]理解“Future”
  • 原文地址:https://www.cnblogs.com/gnool/p/8197029.html
Copyright © 2011-2022 走看看