zoukankan      html  css  js  c++  java
  • keras_API汇总积累(熟读手册)一,快速开始

    一,快速开始

    建立(sequential和add两种方式)

    1.1sequential

    from keras.models import Sequential

    from keras.layers  import Dense,Activation,Dropout,Flatten,Conv2D,MaxPooling2D,Embedding,LSTM,Conv1D,GlobalAveragePooling1D,MaxPooling1D

    from keras.optimizers import SGD

    model=Sequential([Dense(32,input_shape=(784,)),Activation('relu'),Dsense(10),Activation('softmax')])#在一个sequential模型中添加了两个全连接层和两个激活函数

    1.2add

    model.Sequential()

    model.add(Dense(32,input_dim=784))#3D时,还需要input_length

    model.add(Activation('relu'))

    1.3other

    model.add(Embedding(max_features,output_dim=256))

    model.add(LSTM(128,return_sequences=True, stateful=True))

    model.add(Conv2D(32,(3,3),activation='relu',input_shape=(100,100,3)))

    model.add(MaxPooling2D(pool_size=(2,2)))

    model.add(Flatten())

    model.add(Dropout(0.5))

    model.add(Conv1D(64,3,activation='relu',input_shape=(64,100)))

    model.add(MaxPooling1D(3)

    model.add(GlobalAveragePooling1D())

    编译(三个参数:optimizer,loss,metrics(可以自定义评估函数eg:mean_pred))

    自己设置优化器:sgd=SGD(lr=0.01,decay=1e-6,momentum=0.9,nesrerov=True)

    model.compile(loss='categorical_crossentropy',optimizer='rmsprop',metrics=['accuracy',mean_pred])

     compile(optimizer, loss=None, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)

    训练

    model.fit(x,y,batch_size=32,epochs=10)

    fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

    train_on_batch(x, y, sample_weight=None, class_weight=None)

    test_on_batch(x, y, sample_weight=None)

    评估

    score=model.evaluate(x_test,y_test,batch_size=32)

    evaluate(x=None, y=None, batch_size=None, verbose=1, sample_weight=None, steps=None)

    预测

    predict(x, batch_size=None, verbose=0, steps=None)

    a.打印出网络每一层输入输出详情的png图片:先装graphviz,链接:https://pan.baidu.com/s/1u64HriYy4KQ_BhJE8MMTkA ,提取码:zqxa》》》》安装到电脑后,打开Graphviz2.38in,复制目录添加到系统环境变量的path中》》》》cmd:pip install pydot=1.2.3》》》》cmd: dot -version查看安装成功》》》》keras.utils.plot_model(model, 'model.png', show_shapes=True) 》》》就会把你的model保存到文件夹下,名字model.png。

    b.标签转换为one-hot:one_hot_labels=keras.utils.to_categorical(y,num_classes=10)

  • 相关阅读:
    远程办公的一天:魔幻24小时
    LVS:三种负载均衡方式比较
    程序员的二十句励志名言,看看你最喜欢哪句?
    个人服务器开通~
    jquery大全
    CSS大全
    英语中的连词说明
    高版本SqlServer转低版本SqlServer经验总结
    SQLServer中,sa帐号旁边有个小红箭头
    Entity Framework GroupBy usage
  • 原文地址:https://www.cnblogs.com/Turing-dz/p/13030192.html
Copyright © 2011-2022 走看看