zoukankan      html  css  js  c++  java
  • 序贯模型

     

    模型搭建

    举一个最简单的MLP例子,这下面我们添加的都是全连接层

    from keras.models import Sequential

    from keras.layers import Dense, Activation

    model = Sequential()  #序贯模型

    model.add(Dense(units=64, input_dim=100))

    model.add(Activation("relu"))     

    model.add(Dense(units=10))

    model.add(Activation("softmax"))

    #或者使用一次性搭建的方式

    model = Sequential([Dense(32, units=784),Activation('relu'),Dense(10),Activation('softmax')])

    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])   #通过compile来编译模型

    from keras.optimizers import SGD  #定制损失函数。Keras里也封装好了很多优化器和损失函数

    model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.01, momentum=0.9, nesterov=True))

    输入数据并训练

    注意:batch_size太大可能不能收敛到最低点,batch_size太小测试的准确率会剧烈震荡

    1)model.fit(x_train, y_train, epochs=5, batch_size=32)

    2)model.train_on_batch(x_batch, y_batch)  #自己定义batch训练

    3)如果你的数据量很大,你可能要用到fit_generator

    def generate_arrays_from_file(path):

        while 1:

            f = open(path)

            for line in f:

                x, y = process_line(line)

                img = load_images(x)

                yield (img, y)

            f.close()

    model.fit_generator(generate_arrays_from_file('/my_file.txt'), samples_per_epoch=10000, nb_epoch=10)

    测试集评估与预测

    在测试集上评估效果

    loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128)

    实际预测

    classes = model.predict(x_test, batch_size=128)

    优化器optimizer、损失函数loss、评估指标metrics

    # 多分类问题

    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

    # 二分类问题

    model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])

    # 回归问题

    model.compile(optimizer='rmsprop', loss='mse')

    # 自定义metrics

    import keras.backend as K

    def mean_pred(y_true, y_pred):

        return K.mean(y_pred)

    model.compile(optimizer='rmsprop',

                  loss='binary_crossentropy',

                  metrics=['accuracy', mean_pred])

     

  • 相关阅读:
    java之JDBC
    git删除未监视的文件
    java之正则表达式
    linux命令之信息显示与搜索文件命令
    linux命令之文件备份与压缩命令
    gitlab中修改项目名称客户端修改方法
    linux中使用unzip命令中文乱码解决办法
    使用Python进行统计量描述
    Machine Learning
    Courase Neural Networks for Machine Learning Lecture1 Note
  • 原文地址:https://www.cnblogs.com/yongfuxue/p/10095895.html
Copyright © 2011-2022 走看看