zoukankan      html  css  js  c++  java
  • tencent_2.1_linear_regression_model

     linear_regression_model.py

    import tensorflow as tf
    import numpy as np
    
    class linearRegressionModel:
        """Linear regression ``y = x @ w + b`` trained with Adam.

        The loss is the mean squared error plus ``l2_weight`` times the mean
        of the squared weights.  The model is built with the TensorFlow 1.x
        graph API (``tf.placeholder`` / ``tf.Session``) and owns its session;
        call :meth:`close` when the model is no longer needed.

        Args:
            x_dimen: number of input features (columns of ``x``).
            l2_weight: coefficient of the L2 penalty on ``w`` (default 0.15).
            learning_rate: Adam learning rate (default 0.1).
        """

        def __init__(self, x_dimen, l2_weight=0.15, learning_rate=0.1):
            self.x_dimen = x_dimen
            self.l2_weight = l2_weight
            self.learning_rate = learning_rate
            # Read position of the mini-batch cursor used by next_batch().
            self._index_in_epoch = 0
            self.constructModel()
            self.sess = tf.Session()
            self.sess.run(tf.global_variables_initializer())

        def close(self):
            """Release the TensorFlow session held by this model."""
            self.sess.close()

        def weight_variable(self, shape):
            """Create a variable initialised from a truncated normal (stddev 0.1)."""
            initial = tf.truncated_normal(shape, stddev = 0.1)
            return tf.Variable(initial)

        def bias_variable(self, shape):
            """Create a variable initialised to the constant 0.1."""
            initial = tf.constant(0.1, shape = shape)
            return tf.Variable(initial)

        def constructModel(self):
            """Build the placeholders, parameters, prediction, loss and train op."""
            self.x = tf.placeholder(tf.float32, [None, self.x_dimen])
            self.y = tf.placeholder(tf.float32, [None, 1])
            self.w = self.weight_variable([self.x_dimen, 1])
            self.b = self.bias_variable([1])
            self.y_prec = tf.nn.bias_add(tf.matmul(self.x, self.w), self.b)

            mse = tf.reduce_mean(tf.squared_difference(self.y_prec, self.y))
            # Penalty is the *mean* of squared weights, scaled by l2_weight.
            l2 = tf.reduce_mean(tf.square(self.w))
            self.loss = mse + self.l2_weight * l2
            self.train_step = tf.train.AdamOptimizer(
                self.learning_rate).minimize(self.loss)

        def next_batch(self, batch_size):
            """Return the next ``(datas, labels)`` mini-batch of ``batch_size`` rows.

            Rows are taken sequentially.  When the remaining rows cannot fill
            a whole batch, the data are reshuffled and a new epoch starts at
            row 0; the leftover tail rows are skipped for that epoch.

            Raises:
                ValueError: if ``batch_size`` is not in ``1..n_samples``.
            """
            # Validate before touching any state.  A plain ``assert`` would be
            # stripped under ``python -O``.
            if not 0 < batch_size <= self._num_datas:
                raise ValueError(
                    "batch_size must be in 1..%d, got %d"
                    % (self._num_datas, batch_size))
            start = self._index_in_epoch
            self._index_in_epoch += batch_size
            if self._index_in_epoch > self._num_datas:
                perm = np.random.permutation(self._num_datas)
                self._datas = self._datas[perm]
                self._labels = self._labels[perm]
                start = 0
                self._index_in_epoch = batch_size
            end = self._index_in_epoch
            return self._datas[start:end], self._labels[start:end]

        def train(self, x_train, y_train, steps=5000, batch_size=100,
                  log_every=10):
            """Fit the model with ``steps`` Adam updates on mini-batches.

            Args:
                x_train: 2-D array with ``x_dimen`` columns (fed to ``self.x``).
                y_train: array with one column (fed to ``self.y``).
                steps: number of optimizer steps (default 5000).
                batch_size: rows per mini-batch (default 100).
                log_every: print the mini-batch loss every this many steps;
                    0 or None disables logging (default 10).
            """
            self._datas = x_train
            self._labels = y_train
            self._num_datas = x_train.shape[0]
            # Start a fresh pass over the new data instead of resuming at the
            # position left behind by an earlier call.
            self._index_in_epoch = 0
            for i in range(steps):
                batch_x, batch_y = self.next_batch(batch_size)
                feed = {self.x: batch_x, self.y: batch_y}
                self.sess.run(self.train_step, feed_dict=feed)
                if log_every and i % log_every == 0:
                    # This loss is measured on the training mini-batch, not on
                    # held-out data.
                    train_loss = self.sess.run(self.loss, feed_dict=feed)
                    print("step %d,train_loss %f" % (i, train_loss))

        def predict_batch(self, arr, batch_size):
            """Yield consecutive slices of ``arr`` of at most ``batch_size`` rows."""
            for i in range(0, len(arr), batch_size):
                yield arr[i:i + batch_size]

        def predict(self, x_predict, batch_size=100):
            """Return predictions for ``x_predict`` stacked into one column array.

            Inputs are fed through the graph in chunks of ``batch_size`` rows.
            An empty input yields an empty ``(0, 1)`` array rather than
            failing in ``np.vstack``.
            """
            pred_list = [
                self.sess.run(self.y_prec, feed_dict={self.x: chunk})
                for chunk in self.predict_batch(x_predict, batch_size)
            ]
            if not pred_list:
                return np.empty((0, 1), dtype=np.float32)
            return np.vstack(pred_list)

    run.py

    from sklearn.model_selection import train_test_split
    from sklearn.metrics import r2_score
    from sklearn.datasets import make_regression
    from sklearn.linear_model import LinearRegression
    from linear_regression_model import linearRegressionModel as lrm
    
    if __name__ == '__main__':
        # Compare the TensorFlow linear model against scikit-learn's
        # LinearRegression on the same synthetic regression problem.
        x, y = make_regression(7000)
        x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)
        # The TF model's label placeholder has shape [None, 1], so feed it
        # column vectors.
        y_lrm_train = y_train.reshape(-1, 1)
        y_lrm_test = y_test.reshape(-1, 1)

        linear = lrm(x.shape[1])
        linear.train(x_train, y_lrm_train)
        y_predict_tf = linear.predict(x_test)
        # r2_score takes (y_true, y_pred) in that order; R^2 is not symmetric.
        print("Tensorflow R2: ", r2_score(y_lrm_test.ravel(), y_predict_tf.ravel()))

        lr = LinearRegression()
        y_predict_sk = lr.fit(x_train, y_train).predict(x_test)
        print("Sklearn R2: ", r2_score(y_test, y_predict_sk))

    结果:

    ubuntu@VM-12-146-ubuntu:~$ python run.py
    step 0,test_loss 50554.742188
    step 10,test_loss 53487.046875
    step 20,test_loss 36099.449219
    step 30,test_loss 50567.339844
    step 40,test_loss 45398.449219
    step 50,test_loss 40298.109375
    step 60,test_loss 40552.335938
    step 70,test_loss 33812.503906
    step 80,test_loss 39265.847656
    step 90,test_loss 36639.019531
    step 100,test_loss 38088.800781
    step 110,test_loss 34145.976562
    step 120,test_loss 34928.343750
    step 130,test_loss 27798.576172
    step 140,test_loss 27276.896484
    step 150,test_loss 24163.076172
    step 160,test_loss 27816.298828
    step 170,test_loss 27190.076172
    step 180,test_loss 23125.751953
    step 190,test_loss 22233.732422
    step 200,test_loss 25407.130859
    step 210,test_loss 24661.605469
    step 220,test_loss 26708.384766
    step 230,test_loss 19841.687500
    step 240,test_loss 24072.408203
    step 250,test_loss 17266.611328
    step 260,test_loss 26461.554688
    step 270,test_loss 22570.013672
    step 280,test_loss 24905.404297
    step 290,test_loss 15179.116211
    step 300,test_loss 18661.962891
    step 310,test_loss 19683.906250
    step 320,test_loss 15327.764648
    step 330,test_loss 13168.517578
    step 340,test_loss 18685.867188
    step 350,test_loss 15487.893555
    step 360,test_loss 14875.776367
    step 370,test_loss 13526.083008
    step 380,test_loss 12865.832031
    step 390,test_loss 15644.985352
    step 400,test_loss 15185.415039
    step 410,test_loss 14463.912109
    step 420,test_loss 15174.072266
    step 430,test_loss 12386.041016
    step 440,test_loss 11713.977539
    step 450,test_loss 13448.644531
    step 460,test_loss 11875.729492
    step 470,test_loss 10373.699219
    step 480,test_loss 10216.433594
    step 490,test_loss 10364.944336
    step 500,test_loss 11086.302734
    step 510,test_loss 5693.416992
    step 520,test_loss 8974.305664
    step 530,test_loss 10542.807617
    step 540,test_loss 8704.875977
    step 550,test_loss 7724.854980
    step 560,test_loss 10312.858398
    step 570,test_loss 8256.900391
    step 580,test_loss 7817.630859
    step 590,test_loss 5888.268555
    step 600,test_loss 6200.811523
    step 610,test_loss 8451.625000
    step 620,test_loss 5280.057617
    step 630,test_loss 6405.470215
    step 640,test_loss 7642.217773
    step 650,test_loss 7243.946777
    step 660,test_loss 5507.767090
    step 670,test_loss 5810.758789
    step 680,test_loss 5605.552246
    step 690,test_loss 4809.159180
    step 700,test_loss 5075.568848
    step 710,test_loss 6186.606445
    step 720,test_loss 4644.177734
    step 730,test_loss 3405.376953
    step 740,test_loss 3963.750732
    step 750,test_loss 4160.150879
    step 760,test_loss 3322.133301
    step 770,test_loss 4439.647461
    step 780,test_loss 2879.551758
    step 790,test_loss 4686.266602
    step 800,test_loss 3389.883789
    step 810,test_loss 2878.820801
    step 820,test_loss 2689.784668
    step 830,test_loss 3682.542969
    step 840,test_loss 4084.625977
    step 850,test_loss 2865.651611
    step 860,test_loss 2581.408936
    step 870,test_loss 2442.209717
    step 880,test_loss 2742.542969
    step 890,test_loss 2850.451172
    step 900,test_loss 2632.563477
    step 910,test_loss 2156.909180
    step 920,test_loss 2305.271729
    step 930,test_loss 1993.062134
    step 940,test_loss 2311.316162
    step 950,test_loss 1890.035156
    step 960,test_loss 1763.053955
    step 970,test_loss 1725.859131
    step 980,test_loss 1759.168091
    step 990,test_loss 1323.621216
    step 1000,test_loss 1866.871338
    step 1010,test_loss 1352.615479
    step 1020,test_loss 1616.773560
    step 1030,test_loss 1143.031250
    step 1040,test_loss 1489.623291
    step 1050,test_loss 1476.880371
    step 1060,test_loss 1370.258667
    step 1070,test_loss 1141.999878
    step 1080,test_loss 1166.933228
    step 1090,test_loss 1108.787476
    step 1100,test_loss 960.834412
    step 1110,test_loss 1265.723999
    step 1120,test_loss 845.209717
    step 1130,test_loss 1160.951294
    step 1140,test_loss 1252.884766
    step 1150,test_loss 947.760193
    step 1160,test_loss 949.110107
    step 1170,test_loss 703.865845
    step 1180,test_loss 785.185425
    step 1190,test_loss 909.500916
    step 1200,test_loss 739.722473
    step 1210,test_loss 752.640686
    step 1220,test_loss 621.177429
    step 1230,test_loss 664.028809
    step 1240,test_loss 842.765198
    step 1250,test_loss 537.212891
    step 1260,test_loss 612.894226
    step 1270,test_loss 515.576599
    step 1280,test_loss 595.646362
    step 1290,test_loss 591.783936
    step 1300,test_loss 636.358215
    step 1310,test_loss 421.886414
    step 1320,test_loss 454.036652
    step 1330,test_loss 441.143738
    step 1340,test_loss 393.075867
    step 1350,test_loss 415.702820
    step 1360,test_loss 448.445557
    step 1370,test_loss 350.026550
    step 1380,test_loss 391.277832
    step 1390,test_loss 458.568481
    step 1400,test_loss 423.181671
    step 1410,test_loss 410.131195
    step 1420,test_loss 336.751404
    step 1430,test_loss 279.585114
    step 1440,test_loss 244.326080
    step 1450,test_loss 322.283997
    step 1460,test_loss 275.522095
    step 1470,test_loss 240.525375
    step 1480,test_loss 272.072205
    step 1490,test_loss 202.847000
    step 1500,test_loss 276.935272
    step 1510,test_loss 211.451935
    step 1520,test_loss 235.174835
    step 1530,test_loss 201.162201
    step 1540,test_loss 190.168976
    step 1550,test_loss 184.773560
    step 1560,test_loss 214.716644
    step 1570,test_loss 171.767059
    step 1580,test_loss 174.465897
    step 1590,test_loss 167.292999
    step 1600,test_loss 169.164307
    step 1610,test_loss 169.313507
    step 1620,test_loss 172.326813
    step 1630,test_loss 157.986420
    step 1640,test_loss 146.000031
    step 1650,test_loss 134.339294
    step 1660,test_loss 134.111725
    step 1670,test_loss 129.241135
    step 1680,test_loss 133.466339
    step 1690,test_loss 128.226562
    step 1700,test_loss 125.525810
    step 1710,test_loss 115.412491
    step 1720,test_loss 111.956238
    step 1730,test_loss 114.153198
    step 1740,test_loss 115.078041
    step 1750,test_loss 107.463058
    step 1760,test_loss 100.881134
    step 1770,test_loss 108.238266
    step 1780,test_loss 105.143570
    step 1790,test_loss 105.204819
    step 1800,test_loss 92.495026
    step 1810,test_loss 96.282784
    step 1820,test_loss 94.634003
    step 1830,test_loss 88.530350
    step 1840,test_loss 94.419586
    step 1850,test_loss 90.118423
    step 1860,test_loss 90.100471
    step 1870,test_loss 92.007225
    step 1880,test_loss 83.837082
    step 1890,test_loss 88.835732
    step 1900,test_loss 84.472664
    step 1910,test_loss 84.264526
    step 1920,test_loss 83.923073
    step 1930,test_loss 81.012611
    step 1940,test_loss 77.510666
    step 1950,test_loss 81.193970
    step 1960,test_loss 81.355026
    step 1970,test_loss 79.240715
    step 1980,test_loss 75.555099
    step 1990,test_loss 76.817108
    step 2000,test_loss 77.046547
    step 2010,test_loss 75.549843
    step 2020,test_loss 77.543091
    step 2030,test_loss 74.993935
    step 2040,test_loss 75.710793
    step 2050,test_loss 74.185966
    step 2060,test_loss 77.117058
    step 2070,test_loss 74.278481
    step 2080,test_loss 73.415520
    step 2090,test_loss 71.878136
    step 2100,test_loss 73.699829
    step 2110,test_loss 73.163803
    step 2120,test_loss 72.371719
    step 2130,test_loss 73.948982
    step 2140,test_loss 72.047447
    step 2150,test_loss 71.245361
    step 2160,test_loss 72.283531
    step 2170,test_loss 71.045090
    step 2180,test_loss 70.661865
    step 2190,test_loss 70.973740
    step 2200,test_loss 71.079689
    step 2210,test_loss 71.110039
    step 2220,test_loss 70.308311
    step 2230,test_loss 69.957397
    step 2240,test_loss 70.038406
    step 2250,test_loss 70.325066
    step 2260,test_loss 70.040375
    step 2270,test_loss 70.044968
    step 2280,test_loss 69.895622
    step 2290,test_loss 69.759949
    step 2300,test_loss 69.858398
    step 2310,test_loss 69.752213
    step 2320,test_loss 69.294426
    step 2330,test_loss 69.490250
    step 2340,test_loss 69.516228
    step 2350,test_loss 69.264801
    step 2360,test_loss 69.474541
    step 2370,test_loss 69.354004
    step 2380,test_loss 69.242920
    step 2390,test_loss 69.212044
    step 2400,test_loss 69.350334
    step 2410,test_loss 69.142731
    step 2420,test_loss 69.137589
    step 2430,test_loss 69.149040
    step 2440,test_loss 69.254677
    step 2450,test_loss 69.260796
    step 2460,test_loss 69.097015
    step 2470,test_loss 69.137276
    step 2480,test_loss 69.074081
    step 2490,test_loss 69.078163
    step 2500,test_loss 69.126526
    step 2510,test_loss 69.039513
    step 2520,test_loss 68.991615
    step 2530,test_loss 68.940292
    step 2540,test_loss 68.944313
    step 2550,test_loss 68.915138
    step 2560,test_loss 68.998390
    step 2570,test_loss 68.909355
    step 2580,test_loss 68.906288
    step 2590,test_loss 68.861900
    step 2600,test_loss 68.936501
    step 2610,test_loss 68.967476
    step 2620,test_loss 68.890762
    step 2630,test_loss 68.898155
    step 2640,test_loss 68.884277
    step 2650,test_loss 68.880043
    step 2660,test_loss 68.900604
    step 2670,test_loss 68.914810
    step 2680,test_loss 68.907867
    step 2690,test_loss 68.896652
    step 2700,test_loss 68.859848
    step 2710,test_loss 68.878159
    step 2720,test_loss 68.902023
    step 2730,test_loss 68.873718
    step 2740,test_loss 68.916138
    step 2750,test_loss 68.877785
    step 2760,test_loss 68.864380
    step 2770,test_loss 68.869530
    step 2780,test_loss 68.861702
    step 2790,test_loss 68.882233
    step 2800,test_loss 68.869751
    step 2810,test_loss 68.869904
    step 2820,test_loss 68.888641
    step 2830,test_loss 68.850166
    step 2840,test_loss 68.856880
    step 2850,test_loss 68.889809
    step 2860,test_loss 68.859596
    step 2870,test_loss 68.855354
    step 2880,test_loss 68.882050
    step 2890,test_loss 68.885651
    step 2900,test_loss 68.859543
    step 2910,test_loss 68.818733
    step 2920,test_loss 68.832703
    step 2930,test_loss 68.852486
    step 2940,test_loss 68.870384
    step 2950,test_loss 68.855949
    step 2960,test_loss 68.841423
    step 2970,test_loss 68.872490
    step 2980,test_loss 68.883705
    step 2990,test_loss 68.855988
    step 3000,test_loss 68.870636
    step 3010,test_loss 68.862991
    step 3020,test_loss 68.870255
    step 3030,test_loss 68.863777
    step 3040,test_loss 68.900009
    step 3050,test_loss 68.891266
    step 3060,test_loss 68.879257
    step 3070,test_loss 68.858406
    step 3080,test_loss 68.860107
    step 3090,test_loss 68.874466
    step 3100,test_loss 68.860497
    step 3110,test_loss 68.858696
    step 3120,test_loss 68.883636
    step 3130,test_loss 68.876640
    step 3140,test_loss 68.868874
    step 3150,test_loss 68.874413
    step 3160,test_loss 68.827744
    step 3170,test_loss 68.875488
    step 3180,test_loss 68.881233
    step 3190,test_loss 68.859863
    step 3200,test_loss 68.885933
    step 3210,test_loss 68.869545
    step 3220,test_loss 68.842995
    step 3230,test_loss 68.863731
    step 3240,test_loss 68.871620
    step 3250,test_loss 68.887390
    step 3260,test_loss 68.847641
    step 3270,test_loss 68.856567
    step 3280,test_loss 68.860397
    step 3290,test_loss 68.862595
    step 3300,test_loss 68.884712
    step 3310,test_loss 68.877060
    step 3320,test_loss 68.854836
    step 3330,test_loss 68.852036
    step 3340,test_loss 68.841431
    step 3350,test_loss 68.856216
    step 3360,test_loss 68.901741
    step 3370,test_loss 68.879402
    step 3380,test_loss 68.848961
    step 3390,test_loss 68.894516
    step 3400,test_loss 68.845978
    step 3410,test_loss 68.843407
    step 3420,test_loss 68.880623
    step 3430,test_loss 68.874138
    step 3440,test_loss 68.875328
    step 3450,test_loss 68.847839
    step 3460,test_loss 68.856407
    step 3470,test_loss 68.871277
    step 3480,test_loss 68.878189
    step 3490,test_loss 68.884209
    step 3500,test_loss 68.842300
    step 3510,test_loss 68.860901
    step 3520,test_loss 68.872070
    step 3530,test_loss 68.859085
    step 3540,test_loss 68.865288
    step 3550,test_loss 68.876419
    step 3560,test_loss 68.869156
    step 3570,test_loss 68.886154
    step 3580,test_loss 68.872124
    step 3590,test_loss 68.894897
    step 3600,test_loss 68.877914
    step 3610,test_loss 68.880173
    step 3620,test_loss 68.901749
    step 3630,test_loss 68.838867
    step 3640,test_loss 68.871284
    step 3650,test_loss 68.844742
    step 3660,test_loss 68.849792
    step 3670,test_loss 68.876732
    step 3680,test_loss 68.837036
    step 3690,test_loss 68.862450
    step 3700,test_loss 68.875610
    step 3710,test_loss 68.856522
    step 3720,test_loss 68.859993
    step 3730,test_loss 68.869888
    step 3740,test_loss 68.853325
    step 3750,test_loss 68.863533
    step 3760,test_loss 68.858566
    step 3770,test_loss 68.843681
    step 3780,test_loss 68.858315
    step 3790,test_loss 68.876717
    step 3800,test_loss 68.875565
    step 3810,test_loss 68.901901
    step 3820,test_loss 68.881500
    step 3830,test_loss 68.873932
    step 3840,test_loss 68.833977
    step 3850,test_loss 68.854820
    step 3860,test_loss 68.882072
    step 3870,test_loss 68.848206
    step 3880,test_loss 68.874969
    step 3890,test_loss 68.866570
    step 3900,test_loss 68.849907
    step 3910,test_loss 68.879288
    step 3920,test_loss 68.866447
    step 3930,test_loss 68.845490
    step 3940,test_loss 68.854675
    step 3950,test_loss 68.858597
    step 3960,test_loss 68.850449
    step 3970,test_loss 68.853188
    step 3980,test_loss 68.881699
    step 3990,test_loss 68.852928
    step 4000,test_loss 68.903976
    step 4010,test_loss 68.883759
    step 4020,test_loss 68.873917
    step 4030,test_loss 68.858192
    step 4040,test_loss 68.883629
    step 4050,test_loss 68.888466
    step 4060,test_loss 68.867790
    step 4070,test_loss 68.853691
    step 4080,test_loss 68.888000
    step 4090,test_loss 68.863617
    step 4100,test_loss 68.844299
    step 4110,test_loss 68.887550
    step 4120,test_loss 68.867096
    step 4130,test_loss 68.840004
    step 4140,test_loss 68.872429
    step 4150,test_loss 68.880051
    step 4160,test_loss 68.840759
    step 4170,test_loss 68.876610
    step 4180,test_loss 68.866318
    step 4190,test_loss 68.864395
    step 4200,test_loss 68.866844
    step 4210,test_loss 68.870064
    step 4220,test_loss 68.835907
    step 4230,test_loss 68.853455
    step 4240,test_loss 68.866196
    step 4250,test_loss 68.892769
    step 4260,test_loss 68.853798
    step 4270,test_loss 68.830009
    step 4280,test_loss 68.863434
    step 4290,test_loss 68.843948
    step 4300,test_loss 68.881226
    step 4310,test_loss 68.859001
    step 4320,test_loss 68.858551
    step 4330,test_loss 68.865173
    step 4340,test_loss 68.851601
    step 4350,test_loss 68.886436
    step 4360,test_loss 68.861244
    step 4370,test_loss 68.834076
    step 4380,test_loss 68.848862
    step 4390,test_loss 68.852104
    step 4400,test_loss 68.858315
    step 4410,test_loss 68.841331
    step 4420,test_loss 68.855446
    step 4430,test_loss 68.870331
    step 4440,test_loss 68.888527
    step 4450,test_loss 68.869141
    step 4460,test_loss 68.870499
    step 4470,test_loss 68.873802
    step 4480,test_loss 68.856995
    step 4490,test_loss 68.861732
    step 4500,test_loss 68.848595
    step 4510,test_loss 68.877945
    step 4520,test_loss 68.862206
    step 4530,test_loss 68.872231
    step 4540,test_loss 68.858627
    step 4550,test_loss 68.879288
    step 4560,test_loss 68.857910
    step 4570,test_loss 68.884712
    step 4580,test_loss 68.886894
    step 4590,test_loss 68.882500
    step 4600,test_loss 68.884560
    step 4610,test_loss 68.845879
    step 4620,test_loss 68.844276
    step 4630,test_loss 68.860344
    step 4640,test_loss 68.869537
    step 4650,test_loss 68.862511
    step 4660,test_loss 68.886688
    step 4670,test_loss 68.858307
    step 4680,test_loss 68.870544
    step 4690,test_loss 68.895294
    step 4700,test_loss 68.836151
    step 4710,test_loss 68.854034
    step 4720,test_loss 68.896736
    step 4730,test_loss 68.873703
    step 4740,test_loss 68.850273
    step 4750,test_loss 68.877182
    step 4760,test_loss 68.850876
    step 4770,test_loss 68.841743
    step 4780,test_loss 68.863495
    step 4790,test_loss 68.880394
    step 4800,test_loss 68.865532
    step 4810,test_loss 68.880829
    step 4820,test_loss 68.898735
    step 4830,test_loss 68.868607
    step 4840,test_loss 68.851143
    step 4850,test_loss 68.888962
    step 4860,test_loss 68.861221
    step 4870,test_loss 68.856529
    step 4880,test_loss 68.870476
    step 4890,test_loss 68.859734
    step 4900,test_loss 68.859306
    step 4910,test_loss 68.870979
    step 4920,test_loss 68.879097
    step 4930,test_loss 68.886711
    step 4940,test_loss 68.869553
    step 4950,test_loss 68.861496
    step 4960,test_loss 68.846848
    step 4970,test_loss 68.867828
    step 4980,test_loss 68.883766
    step 4990,test_loss 68.881592
    ('Tensorflow R2: ', 0.99999761462739589)
    ('Sklearn R2: ', 1.0)
    ubuntu@VM-12-146-ubuntu:~$ 
    View Code
  • 相关阅读:
    三、Pandas入门
    二、NumPy入门
    jQuery模拟angular的数据绑定
    ajax里的getJSON的用法
    SQL中关于传递参数为Null的示例
    原生ajax示例
    页面自增加示例
    angular1数据绑定例子
    angular2 工程目录结构介绍
    angular js环境配置
  • 原文地址:https://www.cnblogs.com/exciting/p/11322338.html
Copyright © 2011-2022 走看看