zoukankan      html  css  js  c++  java
  • RNN入门(4)利用LSTM实现整数加法运算

      本文将介绍LSTM模型在实现整数加法方面的应用。
      我们以0-255之间的整数加法为例,生成的结果在0到510之间。为了能利用深度学习模型模拟整数的加法运算,我们需要将输入的两个加数和输出的结果用二进制表示,这样就能得到向量,如加数在0-255内,可以用8位0-1向量来表示,前面的空位用0填充;结果在0-510内,可以用9位0-1向量来表示,前面的空位用0填充。因为两个加数均在0-255内变化,所以共有256*256=65536个输入向量以及65536个输出向量,输入向量为两个加数的二进制向量的拼接结果,因而是个16为的输入向量。用以下的Python代码可以模拟以上过程:

    import numpy as np
    
    # 最多8位二进制
    BINARY_DIM = 8
    
    # 将整数表示成为binary_dim位的二进制数,高位用0补齐
    def int_2_binary(number, binary_dim):
        binary_list = list(map(lambda x: int(x), bin(number)[2:]))
        number_dim = len(binary_list)
        result_list = [0]*(binary_dim-number_dim)+binary_list
        return result_list
    
    # 将一个二进制数组转为整数
    def binary2int(binary_array):
        out = 0
        for index, x in enumerate(reversed(binary_array)):
            out += x * pow(2, index)
        return out
    
    # 将[0,2**BINARY_DIM)所有数表示成二进制
    binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)])
    # print(binary)
    
    # 样本的输入向量和输出向量
    dataX = []
    dataY = []
    for i in range(binary.shape[0]):
        for j in range(binary.shape[0]):
            dataX.append(np.append(binary[i], binary[j]))
            dataY.append(int_2_binary(i+j, BINARY_DIM+1))
    
    # print(dataX)
    # print(dataY)
    
    # 重新特征X和目标变量Y数组,适应LSTM模型的输入和输出
    X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1))
    # print(X.shape)
    Y = np.array(dataY)
    # print(dataY.shape)
    

    在以上代码中,得到的dataX和dataY以满足要求,但为了能让LSTM模型处理,需要改变这两个数据集的形状。
      我们采用LSTM模型来训练上述数据,LSTM模型的结构很简单,就是简单的一层LSTM层,然后加上Dropout层,最后是全连接层,激活函数采用sigmoid函数,采用的损失函数为平均平方误差。整个结构的示意图如下:

    LSTM模型的结构示意图

    模型训练的代码如下:

    from keras.models import Sequential
    from keras.layers import Dense
    from keras.layers import Dropout
    from keras.layers import LSTM
    from keras import losses
    from keras.utils import plot_model
    
    # 定义LSTM模型
    model = Sequential()
    model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
    model.add(Dropout(0.2))
    model.add(Dense(Y.shape[1], activation='sigmoid'))
    model.compile(loss=losses.mean_squared_error, optimizer='adam')
    # print(model.summary())
    
    # plot model
    plot_model(model, to_file=r'./model.png', show_shapes=True)
    # train model
    epochs = 100
    model.fit(X, Y, epochs=epochs, batch_size=128)
    # save model
    mp = r'./LSTM_Operation.h5'
    model.save(mp)
    

    该LSTM模型每批训练128个样本,共训练100次,采用Adam优化器减少损失值。
      对这个模型进行训练,训练100次,损失值为0.0045。接下来我们就要用这个训练好的模型来预测。我们预测的方法为,虽然挑两个在0-255内的加数,转化为二进制向量作为输入向量,然后由LSTM模型输出结果,将该结果取整作为输出向量中的元素,最后将这个输出向量转化为整数,就是预测的两个加数的和。模型预测的代码如下:

    # use LSTM model to predict
    for _ in range(100):
        start = np.random.randint(0, len(dataX)-1)
        # print(dataX[start])
        number1 = dataX[start][0:BINARY_DIM]
        number2 = dataX[start][BINARY_DIM:]
        print('='*30)
        print('%s: %s'%(number1, binary2int(number1)))
        print('%s: %s'%(number2, binary2int(number2)))
        sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1))
        predict = np.round(model.predict(sample), 0).astype(np.int32)[0]
        print('%s: %s'%(predict, binary2int(predict)))
    

    预测的100组样本的输出结果如下:

    ==============================
    [1 0 0 1 1 1 0 1]: 157
    [0 1 1 1 0 0 0 1]: 113
    [1 0 0 0 0 1 1 1 0]: 270
    ==============================
    [1 1 1 0 1 0 1 0]: 234
    [0 1 0 0 1 1 0 0]: 76
    [1 0 0 1 1 0 1 1 0]: 310
    ==============================
    [1 1 0 0 0 1 0 0]: 196
    [1 1 0 1 1 0 1 1]: 219
    [1 1 0 0 1 1 1 1 1]: 415
    ==============================
    [0 0 1 1 1 0 1 0]: 58
    [0 0 1 0 0 0 1 1]: 35
    [0 0 1 0 1 1 1 0 1]: 93
    ==============================
    [1 0 0 0 0 0 0 0]: 128
    [0 1 1 1 1 0 0 1]: 121
    [0 1 1 1 1 1 0 0 1]: 249
    ==============================
    [1 1 1 1 0 1 1 0]: 246
    [1 1 0 1 0 1 0 1]: 213
    [1 1 1 0 0 1 0 1 1]: 459
    ==============================
    [1 1 1 0 0 1 1 0]: 230
    [1 0 0 0 0 0 0 0]: 128
    [1 0 1 1 0 0 1 1 0]: 358
    ==============================
    [1 0 1 0 0 0 1 1]: 163
    [0 1 1 0 0 1 0 1]: 101
    [1 0 0 0 0 1 0 0 0]: 264
    ==============================
    [1 0 1 0 0 1 1 0]: 166
    [0 1 0 1 0 0 0 0]: 80
    [0 1 1 1 1 0 1 1 0]: 246
    ==============================
    [0 0 0 0 1 0 1 1]: 11
    [0 1 0 0 0 1 0 1]: 69
    [0 0 1 0 1 0 0 0 0]: 80
    ==============================
    [1 1 1 1 0 1 1 1]: 247
    [0 1 1 1 0 0 0 0]: 112
    [1 0 1 1 0 0 1 1 1]: 359
    ==============================
    [1 0 1 0 1 0 0 1]: 169
    [1 1 0 0 0 0 0 0]: 192
    [1 0 1 1 0 1 0 0 1]: 361
    ==============================
    [1 0 1 1 0 0 0 1]: 177
    [1 0 0 0 1 0 1 1]: 139
    [1 0 0 1 1 1 1 0 0]: 316
    ==============================
    [0 1 0 0 0 1 1 0]: 70
    [0 0 1 0 1 1 1 0]: 46
    [0 0 1 1 1 0 1 0 0]: 116
    ==============================
    [1 0 0 1 1 0 1 1]: 155
    [1 1 0 0 0 0 0 1]: 193
    [1 0 1 0 1 1 1 0 0]: 348
    ==============================
    [1 0 1 1 0 0 1 0]: 178
    [1 0 0 0 1 1 1 1]: 143
    [1 0 1 0 0 0 0 0 1]: 321
    ==============================
    [0 1 0 1 1 1 1 1]: 95
    [1 1 1 0 0 1 0 0]: 228
    [1 0 1 0 0 0 0 1 1]: 323
    ==============================
    [1 0 0 1 1 1 1 0]: 158
    [0 0 0 1 1 0 0 1]: 25
    [0 1 0 1 1 0 1 1 1]: 183
    ==============================
    [1 1 1 0 1 0 1 1]: 235
    [1 1 0 0 0 0 0 1]: 193
    [1 1 0 1 0 1 1 0 0]: 428
    ==============================
    [0 1 0 1 1 1 0 1]: 93
    [0 1 1 1 0 1 1 0]: 118
    [0 1 1 0 1 0 0 1 1]: 211
    ==============================
    [1 1 1 1 1 1 1 1]: 255
    [1 1 1 1 1 1 1 0]: 254
    [1 1 1 1 1 1 1 0 1]: 509
    ==============================
    [0 1 0 1 1 0 0 1]: 89
    [0 1 0 1 1 1 1 0]: 94
    [0 1 0 1 1 0 1 1 1]: 183
    ==============================
    [0 1 1 1 0 0 0 0]: 112
    [0 0 1 1 0 1 0 0]: 52
    [0 1 0 1 0 0 1 0 0]: 164
    ==============================
    [1 0 0 0 0 0 0 0]: 128
    [1 1 0 1 1 0 1 0]: 218
    [1 0 1 0 1 1 0 1 0]: 346
    ==============================
    [0 0 1 1 0 1 0 1]: 53
    [1 0 1 1 1 1 1 0]: 190
    [0 1 1 1 1 0 0 1 1]: 243
    ==============================
    [0 1 1 1 1 0 0 0]: 120
    [1 1 0 1 0 1 0 1]: 213
    [1 0 1 0 0 1 1 0 1]: 333
    ==============================
    [0 1 1 1 1 0 1 1]: 123
    [1 1 1 0 1 1 0 1]: 237
    [1 0 1 1 0 1 0 0 0]: 360
    ==============================
    [1 0 0 1 1 0 1 0]: 154
    [0 1 1 0 1 0 0 1]: 105
    [1 0 0 0 0 0 0 1 1]: 259
    ==============================
    [0 0 0 1 1 0 0 1]: 25
    [0 1 0 1 1 0 1 0]: 90
    [0 0 1 1 1 0 0 1 1]: 115
    ==============================
    [1 1 1 1 0 0 0 1]: 241
    [0 0 0 1 1 1 1 1]: 31
    [1 0 0 0 1 0 0 0 0]: 272
    ==============================
    [0 1 0 0 0 1 1 0]: 70
    [1 1 1 0 1 0 0 1]: 233
    [1 0 0 1 0 1 1 1 1]: 303
    ==============================
    [1 0 1 0 1 1 0 1]: 173
    [0 1 1 1 0 1 0 0]: 116
    [1 0 0 1 0 0 0 0 1]: 289
    ==============================
    [0 1 0 0 1 0 0 0]: 72
    [1 1 1 1 1 0 1 0]: 250
    [1 0 1 0 0 0 0 1 0]: 322
    ==============================
    [1 1 1 1 0 0 0 0]: 240
    [0 1 0 0 0 0 1 0]: 66
    [1 0 0 1 1 0 0 1 0]: 306
    ==============================
    [0 1 0 0 0 1 1 1]: 71
    [1 0 0 1 0 1 1 0]: 150
    [0 1 1 0 1 1 1 0 1]: 221
    ==============================
    [0 1 1 0 1 1 0 1]: 109
    [0 0 1 0 0 1 0 1]: 37
    [0 1 0 0 1 0 0 1 0]: 146
    ==============================
    [1 1 0 0 0 0 0 0]: 192
    [1 1 1 0 0 0 0 1]: 225
    [1 1 0 1 0 0 0 0 1]: 417
    ==============================
    [1 0 0 0 0 0 1 1]: 131
    [1 1 0 1 1 1 1 0]: 222
    [1 0 1 1 0 0 0 0 1]: 353
    ==============================
    [0 0 0 0 0 1 0 0]: 4
    [1 1 1 0 0 0 1 0]: 226
    [0 1 1 1 0 0 1 1 0]: 230
    ==============================
    [1 1 1 0 1 1 1 1]: 239
    [1 1 0 1 1 0 1 1]: 219
    [1 1 1 0 0 1 0 1 0]: 458
    ==============================
    [0 0 1 1 0 1 0 1]: 53
    [1 1 1 1 0 0 1 0]: 242
    [1 0 0 1 0 0 1 1 1]: 295
    ==============================
    [1 0 0 1 0 0 0 1]: 145
    [0 1 0 0 0 1 0 0]: 68
    [0 1 1 0 1 0 1 0 1]: 213
    ==============================
    [0 0 1 1 0 0 0 0]: 48
    [1 0 1 1 0 1 1 1]: 183
    [0 1 1 1 0 0 1 1 1]: 231
    ==============================
    [0 1 1 0 0 1 1 1]: 103
    [0 0 0 1 1 1 1 0]: 30
    [0 1 0 0 0 0 1 0 1]: 133
    ==============================
    [0 1 0 1 1 1 0 1]: 93
    [1 1 0 1 0 0 1 0]: 210
    [1 0 0 1 0 1 1 1 1]: 303
    ==============================
    [1 0 0 0 1 0 1 0]: 138
    [0 1 1 1 1 0 0 1]: 121
    [1 0 0 0 0 0 0 1 1]: 259
    ==============================
    [0 0 0 0 0 0 1 1]: 3
    [0 0 1 1 0 0 0 1]: 49
    [0 0 0 1 1 0 1 0 0]: 52
    ==============================
    [1 0 0 0 0 0 1 0]: 130
    [0 0 0 1 0 0 0 0]: 16
    [0 1 0 0 1 0 0 1 0]: 146
    ==============================
    [0 0 0 1 0 0 0 0]: 16
    [1 0 0 1 0 0 1 0]: 146
    [0 1 0 1 0 0 0 1 0]: 162
    ==============================
    [0 1 0 1 0 1 0 0]: 84
    [0 0 0 0 1 1 0 0]: 12
    [0 0 1 1 0 0 0 0 0]: 96
    ==============================
    [1 0 1 0 1 0 1 1]: 171
    [1 1 0 1 1 0 1 1]: 219
    [1 1 0 0 0 0 1 1 0]: 390
    ==============================
    [1 1 1 1 1 1 1 0]: 254
    [0 1 1 0 1 0 1 0]: 106
    [1 0 1 1 0 1 0 0 0]: 360
    ==============================
    [1 0 0 0 0 0 1 0]: 130
    [0 0 0 0 1 1 1 0]: 14
    [0 1 0 0 1 0 0 0 0]: 144
    ==============================
    [1 0 1 0 0 1 0 1]: 165
    [0 0 1 1 1 0 1 1]: 59
    [0 1 1 1 0 0 0 0 0]: 224
    ==============================
    [0 0 1 1 1 0 1 0]: 58
    [1 1 1 1 0 0 1 0]: 242
    [1 0 0 1 0 1 1 0 0]: 300
    ==============================
    [0 1 0 0 1 1 0 1]: 77
    [0 0 0 1 1 1 1 1]: 31
    [0 0 1 1 0 1 1 0 0]: 108
    ==============================
    [1 0 0 1 1 0 1 0]: 154
    [0 1 0 1 0 1 0 1]: 85
    [0 1 1 1 0 1 1 1 1]: 239
    ==============================
    [0 1 1 0 1 1 0 1]: 109
    [0 1 1 0 1 0 0 1]: 105
    [0 1 1 0 1 0 1 1 0]: 214
    ==============================
    [0 1 1 1 1 1 1 1]: 127
    [0 1 1 1 0 0 1 0]: 114
    [0 1 1 1 1 0 0 0 1]: 241
    ==============================
    [0 1 1 0 0 1 0 1]: 101
    [0 1 0 1 0 0 0 0]: 80
    [0 1 0 1 1 0 1 0 1]: 181
    ==============================
    [0 1 1 0 1 1 1 0]: 110
    [0 1 0 1 0 1 1 0]: 86
    [0 1 1 0 0 0 1 0 0]: 196
    ==============================
    [0 0 0 1 0 0 1 1]: 19
    [1 0 0 1 0 0 0 0]: 144
    [0 1 0 1 0 0 0 1 1]: 163
    ==============================
    [1 1 1 1 0 1 0 0]: 244
    [1 1 0 1 0 0 1 1]: 211
    [1 1 1 0 0 0 1 1 1]: 455
    ==============================
    [0 0 0 0 1 1 1 0]: 14
    [1 0 1 1 0 0 1 0]: 178
    [0 1 1 0 0 0 0 0 0]: 192
    ==============================
    [0 1 1 0 0 0 0 0]: 96
    [1 0 0 1 1 1 0 0]: 156
    [0 1 1 1 1 1 1 0 0]: 252
    ==============================
    [0 0 1 1 0 1 0 0]: 52
    [0 1 1 1 1 1 0 1]: 125
    [0 1 0 1 1 0 0 0 1]: 177
    ==============================
    [0 0 0 0 1 1 0 0]: 12
    [0 1 0 1 1 1 0 1]: 93
    [0 0 1 1 0 1 0 0 1]: 105
    ==============================
    [0 1 1 0 0 1 0 1]: 101
    [1 1 0 1 0 1 0 0]: 212
    [1 0 0 1 1 1 0 0 1]: 313
    ==============================
    [1 1 0 0 0 0 0 1]: 193
    [1 1 0 0 1 1 0 1]: 205
    [1 1 0 0 0 1 1 1 0]: 398
    ==============================
    [0 1 1 1 0 0 1 0]: 114
    [0 0 0 0 0 0 0 0]: 0
    [0 0 1 1 1 0 0 1 0]: 114
    ==============================
    [1 0 0 0 1 1 1 0]: 142
    [1 0 1 1 1 1 0 1]: 189
    [1 0 1 0 0 1 0 1 1]: 331
    ==============================
    [1 0 1 1 0 1 1 1]: 183
    [0 1 0 1 0 1 1 0]: 86
    [1 0 0 0 0 1 1 0 1]: 269
    ==============================
    [1 0 1 0 0 0 1 1]: 163
    [1 1 1 0 0 1 0 1]: 229
    [1 1 0 0 0 1 0 0 0]: 392
    ==============================
    [0 0 1 1 0 0 0 1]: 49
    [1 1 1 0 0 1 1 1]: 231
    [1 0 0 0 1 1 0 0 0]: 280
    ==============================
    [1 0 0 0 1 1 1 1]: 143
    [1 0 1 0 1 0 0 0]: 168
    [1 0 0 1 1 0 1 1 1]: 311
    ==============================
    [0 1 0 0 0 0 0 0]: 64
    [0 0 0 0 0 1 0 1]: 5
    [0 0 1 0 0 0 1 0 1]: 69
    ==============================
    [1 1 1 1 1 0 1 1]: 251
    [1 0 1 1 1 0 0 1]: 185
    [1 1 0 1 1 0 1 0 0]: 436
    ==============================
    [1 1 1 0 1 1 1 0]: 238
    [1 1 0 0 0 0 1 0]: 194
    [1 1 0 1 1 0 0 0 0]: 432
    ==============================
    [0 0 1 1 1 1 0 0]: 60
    [0 0 0 1 0 1 1 1]: 23
    [0 0 1 0 1 0 0 1 1]: 83
    ==============================
    [0 1 1 1 0 1 0 0]: 116
    [1 1 1 1 1 1 0 0]: 252
    [1 0 1 1 1 0 0 0 0]: 368
    ==============================
    [1 1 0 1 0 1 1 0]: 214
    [1 1 1 1 0 1 0 0]: 244
    [1 1 1 0 0 1 0 1 0]: 458
    ==============================
    [1 1 1 1 1 1 1 0]: 254
    [1 1 0 1 0 0 0 1]: 209
    [1 1 1 0 0 1 1 1 1]: 463
    ==============================
    [0 0 0 0 0 0 1 0]: 2
    [0 0 0 0 1 1 0 1]: 13
    [0 0 0 0 0 1 1 1 1]: 15
    ==============================
    [0 1 1 0 0 1 1 1]: 103
    [1 0 1 1 1 1 1 0]: 190
    [1 0 0 1 0 0 1 0 1]: 293
    ==============================
    [1 1 1 1 0 1 1 0]: 246
    [0 1 0 1 0 0 1 0]: 82
    [1 0 1 0 0 1 0 0 0]: 328
    ==============================
    [0 1 1 1 0 0 1 1]: 115
    [0 0 1 1 1 0 1 1]: 59
    [0 1 0 1 0 1 1 1 0]: 174
    ==============================
    [0 1 0 1 1 0 0 1]: 89
    [0 1 1 0 1 0 1 1]: 107
    [0 1 1 0 0 0 1 0 0]: 196
    ==============================
    [0 1 0 0 0 1 0 0]: 68
    [0 0 1 1 1 0 0 0]: 56
    [0 0 1 1 1 1 1 0 0]: 124
    ==============================
    [1 1 0 0 1 0 0 0]: 200
    [1 0 1 0 0 0 1 0]: 162
    [1 0 1 1 0 1 0 1 0]: 362
    ==============================
    [1 1 1 1 0 0 1 1]: 243
    [0 1 1 0 0 0 1 1]: 99
    [1 0 1 0 1 0 1 1 0]: 342
    ==============================
    [0 0 1 0 1 0 0 1]: 41
    [0 1 0 0 1 0 0 1]: 73
    [0 0 1 1 1 0 0 1 0]: 114
    ==============================
    [0 0 0 1 1 1 0 1]: 29
    [1 0 1 0 1 1 1 0]: 174
    [0 1 1 0 0 1 0 1 1]: 203
    ==============================
    [0 0 0 0 1 1 1 1]: 15
    [0 0 1 1 1 1 0 1]: 61
    [0 0 1 0 0 1 1 0 0]: 76
    ==============================
    [1 1 1 1 1 0 1 1]: 251
    [1 1 0 1 0 0 0 0]: 208
    [1 1 1 0 0 1 0 1 1]: 459
    ==============================
    [1 1 1 0 1 0 0 0]: 232
    [0 1 1 0 0 0 1 0]: 98
    [1 0 1 0 0 1 0 1 0]: 330
    ==============================
    [1 0 1 1 0 1 0 0]: 180
    [0 1 0 1 0 1 1 1]: 87
    [1 0 0 0 0 1 0 1 1]: 267
    ==============================
    [1 0 0 0 0 1 1 0]: 134
    [1 0 0 1 0 1 0 1]: 149
    [1 0 0 0 1 1 0 1 1]: 283
    ==============================
    [1 0 1 0 1 1 0 1]: 173
    [0 1 1 1 1 1 0 0]: 124
    [1 0 0 1 0 1 0 0 1]: 297
    ==============================
    [0 1 0 0 1 0 0 0]: 72
    [0 1 1 0 0 0 1 1]: 99
    [0 1 0 1 0 1 0 1 1]: 171
    ==============================
    [1 1 0 1 0 1 0 1]: 213
    [0 0 0 1 1 1 1 0]: 30
    [0 1 1 1 1 0 0 1 1]: 243
    

      可以看到,这个简单的LSTM模型的预测的结果全部正确。因此,这就可以用来模拟0-255内的整数的加法运算,是不是很神奇呢?
      如果需要想将加数的范围扩大,只需要改变代码中的BINARY_DIM变量即可。但是,加数的范围越大,样本就越大,如2**10=1024内的加法,就会有1024*1024=1048576个样本,这样大的样本量的无疑需要更多的训练时间。
      本文到此结束,感谢阅读~如果不当之处,请速联系笔者,欢迎大家交流~祝您好运~

    注意:本人现已开通微信公众号: Python爬虫与算法(微信号为:easy_web_scrape), 欢迎大家关注哦~~

    完整的Python代码如下:

    import numpy as np
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.layers import Dropout
    from keras.layers import LSTM
    from keras import losses
    from keras.utils import plot_model
    
    # 最多8位二进制
    BINARY_DIM = 8
    
    # 将整数表示成为binary_dim位的二进制数,高位用0补齐
    def int_2_binary(number, binary_dim):
        binary_list = list(map(lambda x: int(x), bin(number)[2:]))
        number_dim = len(binary_list)
        result_list = [0]*(binary_dim-number_dim)+binary_list
        return result_list
    
    # 将一个二进制数组转为整数
    def binary2int(binary_array):
        out = 0
        for index, x in enumerate(reversed(binary_array)):
            out += x * pow(2, index)
        return out
    
    # 将[0,2**BINARY_DIM)所有数表示成二进制
    binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)])
    # print(binary)
    
    # 样本的输入向量和输出向量
    dataX = []
    dataY = []
    for i in range(binary.shape[0]):
        for j in range(binary.shape[0]):
            dataX.append(np.append(binary[i], binary[j]))
            dataY.append(int_2_binary(i+j, BINARY_DIM+1))
    
    # print(dataX)
    # print(dataY)
    
    # 重新特征X和目标变量Y数组,适应LSTM模型的输入和输出
    X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1))
    # print(X.shape)
    Y = np.array(dataY)
    # print(dataY.shape)
    
    # 定义LSTM模型
    model = Sequential()
    model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
    model.add(Dropout(0.2))
    model.add(Dense(Y.shape[1], activation='sigmoid'))
    model.compile(loss=losses.mean_squared_error, optimizer='adam')
    # print(model.summary())
    
    # plot model
    plot_model(model, to_file=r'./model.png', show_shapes=True)
    # train model
    epochs = 100
    model.fit(X, Y, epochs=epochs, batch_size=128)
    # save model
    mp = r'./LSTM_Operation.h5'
    model.save(mp)
    
    # use LSTM model to predict
    for _ in range(100):
        start = np.random.randint(0, len(dataX)-1)
        # print(dataX[start])
        number1 = dataX[start][0:BINARY_DIM]
        number2 = dataX[start][BINARY_DIM:]
        print('='*30)
        print('%s: %s'%(number1, binary2int(number1)))
        print('%s: %s'%(number2, binary2int(number2)))
        sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1))
        predict = np.round(model.predict(sample), 0).astype(np.int32)[0]
        print('%s: %s'%(predict, binary2int(predict)))
    
  • 相关阅读:
    C语言学习笔记-静态库、动态库的制作和使用
    各种消息队列的对比
    如何使用Jupyter notebook
    Ubuntu16.04配置OpenCV环境
    Docker容器发展历史
    Ubuntu OpenSSH Server
    SpringBoot 系列文章
    SpringBoot 模板 Thymeleaf 的使用
    18、spring注解学习(AOP)——AOP功能测试
    17、spring注解学习(自动装配)——@Profile根据当前环境,动态的激活和切换一系列组件的功能
  • 原文地址:https://www.cnblogs.com/jclian91/p/9867229.html
Copyright © 2011-2022 走看看