  • tflearn: saving a model and retraining it

    From: https://stackoverflow.com/questions/41616292/how-to-load-and-retrain-tflean-model

    This part builds the graph, trains the model and saves it:

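    import tensorflow as tf
    import tflearn
    from tflearn.layers.core import input_data, dropout, fully_connected
    from tflearn.layers.conv import conv_1d, global_max_pool
    from tflearn.layers.merge_ops import merge
    from tflearn.layers.estimator import regression

    # MAX_DOCUMENT_LENGTH, n_words, data, clas and model_path come from the asker's own code.
    graph1 = tf.Graph()
    with graph1.as_default():
        network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
        network = tflearn.embedding(network, input_dim=n_words, output_dim=128)
        # Three parallel 1-D convolutions with kernel sizes 3, 4 and 5
        branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
        branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
        branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
        network = merge([branch1, branch2, branch3], mode='concat', axis=1)
        network = tf.expand_dims(network, 2)  # global_max_pool expects a 4-D tensor
        network = global_max_pool(network)
        network = dropout(network, 0.5)
        network = fully_connected(network, 2, activation='softmax')
        network = regression(network, optimizer='adam', learning_rate=0.001, loss='categorical_crossentropy', name='target')
        model = tflearn.DNN(network, tensorboard_verbose=0)
        # classify_DNN is the asker's own helper (not part of tflearn); it apparently
        # trains `model` and returns it together with evaluation metrics.
        clf, acc, roc_auc, fpr, tpr = classify_DNN(data, clas, model)
        clf.save(model_path)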
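    It helps to track tensor shapes through this network to see why `tf.expand_dims` sits in front of `global_max_pool`. The plain-Python sketch below does that bookkeeping and does not call tflearn. It assumes that a stride-1 `padding='valid'` convolution with kernel size k shortens a sequence of length L to L - k + 1. It also assumes that tflearn's `global_max_pool` takes a 4-D tensor and reduces over its two middle axes. The MAX_DOCUMENT_LENGTH of 100 is a hypothetical value.

```python
# Shape bookkeeping for the CNN above (plain Python, no tflearn).
# Assumptions: a stride-1 'valid' 1-D convolution maps length L to L - k + 1,
# and global_max_pool reduces a 4-D tensor over axes 1 and 2.


def conv1d_valid_len(length, kernel_size, stride=1):
    """Output length of a 1-D convolution with padding='valid'."""
    return (length - kernel_size) // stride + 1


def cnn_shapes(max_len, emb_dim=128, n_filters=128, kernel_sizes=(3, 4, 5), n_classes=2):
    """Per-example shape (batch axis omitted) after each stage of the network."""
    shapes = {"input_data": (max_len,), "embedding": (max_len, emb_dim)}
    branch_lens = [conv1d_valid_len(max_len, k) for k in kernel_sizes]
    for k, n in zip(kernel_sizes, branch_lens):
        shapes[f"conv_1d k={k}"] = (n, n_filters)
    concat_len = sum(branch_lens)  # merge(mode='concat', axis=1) joins along time
    shapes["merge"] = (concat_len, n_filters)
    shapes["expand_dims"] = (concat_len, 1, n_filters)  # tf.expand_dims(network, 2)
    shapes["global_max_pool"] = (n_filters,)  # max over axes 1 and 2
    shapes["fully_connected"] = (n_classes,)
    return shapes


MAX_DOCUMENT_LENGTH = 100  # hypothetical value for illustration
for stage, shape in cnn_shapes(MAX_DOCUMENT_LENGTH).items():
    print(f"{stage:>16}: (None, {', '.join(map(str, shape))})")
```

    With a document length of 100, the three branches yield 98, 97 and 96 time steps. The concatenated tensor therefore has 291 steps of 128 features, and pooling then collapses it to 128 values per example.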
    

    To reload the model for retraining or prediction, rebuild exactly the same network in a new graph and then load the saved weights:

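    # Same imports, MAX_DOCUMENT_LENGTH, n_words and model_path as above.
    MODEL = None
    with tf.Graph().as_default():
        # Rebuild the deep neural network exactly as it was built before saving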
        network = input_data(shape=[None, MAX_DOCUMENT_LENGTH])
        network = tflearn.embedding(network, input_dim=n_words, output_dim=128)
        branch1 = conv_1d(network, 128, 3, padding='valid', activation='relu', regularizer="L2")
        branch2 = conv_1d(network, 128, 4, padding='valid', activation='relu', regularizer="L2")
        branch3 = conv_1d(network, 128, 5, padding='valid', activation='relu', regularizer="L2")
        network = merge([branch1, branch2, branch3], mode='concat', axis=1)
        network = tf.expand_dims(network, 2)
        network = global_max_pool(network)
        network = dropout(network, 0.5)
        network = fully_connected(network, 2, activation='softmax')
        network = regression(network, optimizer='adam', learning_rate=0.001,loss='categorical_crossentropy', name='target')
        new_model = tflearn.DNN(network, tensorboard_verbose=3)
        new_model.load(model_path)
        MODEL = new_model
    

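    Use MODEL for prediction or retraining. The first line and the `with tf.Graph().as_default():` block were the important parts. Building the network in a fresh graph keeps its variable names identical to the ones stored in the checkpoint, and `load` matches variables by name. Posting this for anyone who might need help.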
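    A toy model of TensorFlow 1.x name scoping shows why a fresh graph matters. This is a sketch, not TensorFlow code. It assumes that each graph makes repeated default layer names unique by appending `_1`, `_2`, and so on, and that restoring a checkpoint matches variables by those names. `ToyGraph`, `build_network` and `missing_in_checkpoint` are hypothetical helpers, and the layer names are only illustrative.

```python
# Toy model of TF 1.x name scoping (no TensorFlow needed).
# Assumptions: a graph uniquifies repeated default names as "Conv1D", "Conv1D_1", ...
# and a checkpoint restore looks variables up purely by these names.
from collections import Counter


class ToyGraph:
    """Hypothetical stand-in for tf.Graph: hands out unique layer names."""

    def __init__(self):
        self._counts = Counter()

    def unique_name(self, base):
        n = self._counts[base]
        self._counts[base] += 1
        return base if n == 0 else f"{base}_{n}"


def build_network(graph):
    """Hypothetical layer stack: one embedding, three conv branches, one dense layer."""
    names = [graph.unique_name("Embedding")]
    names += [graph.unique_name("Conv1D") for _ in range(3)]
    names.append(graph.unique_name("FullyConnected"))
    return names


def missing_in_checkpoint(checkpoint_names, graph_names):
    """Layers that a by-name restore could not find in the checkpoint."""
    return [name for name in graph_names if name not in checkpoint_names]


g1 = ToyGraph()
saved = set(build_network(g1))  # names written when the model was saved

# Rebuilding in the SAME graph: every name gets a new suffix.
print(missing_in_checkpoint(saved, build_network(g1)))
# ['Embedding_1', 'Conv1D_3', 'Conv1D_4', 'Conv1D_5', 'FullyConnected_1']

# Rebuilding in a FRESH graph, as `with tf.Graph().as_default():` does.
print(missing_in_checkpoint(saved, build_network(ToyGraph())))
# []
```

    If the network is built a second time in the same graph, every name shifts and none of them can be found in the checkpoint. A fresh graph reproduces the original names.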

    Official example:

    """ An example showing how to save/restore models and retrieve weights. """
    
    from __future__ import absolute_import, division, print_function
    
    import tflearn
    
    import tflearn.datasets.mnist as mnist
    
    # MNIST Data
    X, Y, testX, testY = mnist.load_data(one_hot=True)
    
    # Model
    input_layer = tflearn.input_data(shape=[None, 784], name='input')
    dense1 = tflearn.fully_connected(input_layer, 128, name='dense1')
    dense2 = tflearn.fully_connected(dense1, 256, name='dense2')
    softmax = tflearn.fully_connected(dense2, 10, activation='softmax')
    regression = tflearn.regression(softmax, optimizer='adam',
                                    learning_rate=0.001,
                                    loss='categorical_crossentropy')
    
    # Define classifier, with model checkpoint (autosave)
    model = tflearn.DNN(regression, checkpoint_path='model.tfl.ckpt')
    
    # Train model, with model checkpoint every epoch and every 500 training steps.
    model.fit(X, Y, n_epoch=1,
              validation_set=(testX, testY),
              show_metric=True,
              snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
              snapshot_step=500, # Snapshot (save & evaluate) model every 500 steps.
              run_id='model_and_weights')
    
    
    # ---------------------
    # Save and load a model
    # ---------------------
    
    # Manually save model
    model.save("model.tfl")
    
    # Load a model
    model.load("model.tfl")
    
    # Or Load a model from auto-generated checkpoint
    # >> model.load("model.tfl.ckpt-500")
    
    # Resume training
    model.fit(X, Y, n_epoch=1,
              validation_set=(testX, testY),
              show_metric=True,
              snapshot_epoch=True,
              run_id='model_and_weights')
    
    
    # ------------------
    # Retrieving weights
    # ------------------
    
    # Retrieve a layer's weights by layer name:
    dense1_vars = tflearn.variables.get_layer_variables_by_name('dense1')
    # Get a variable's value, using model `get_weights` method:
    print("Dense1 layer weights:")
    print(model.get_weights(dense1_vars[0]))
    # Or using generic tflearn function:
    print("Dense1 layer biases:")
    with model.session.as_default():
        print(tflearn.variables.get_value(dense1_vars[1]))
    
    # It is also possible to retrieve a layer's weights through its attributes `W`
    # and `b` (if available).
    # Get variable's value, using model `get_weights` method:
    print("Dense2 layer weights:")
    print(model.get_weights(dense2.W))
    # Or using generic tflearn function:
    print("Dense2 layer biases:")
    with model.session.as_default():
        print(tflearn.variables.get_value(dense2.b))
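    Two pieces of bookkeeping help when reading the official example. The first is which checkpoint files `snapshot_step` and `snapshot_epoch` produce. The second is what shapes `get_weights` should return for each dense layer. The sketch below is plain Python and rests on several assumptions that the example does not state:
    - a batch size of 64 (the default of `tflearn.regression`);
    - 55,000 MNIST training images;
    - a checkpoint suffix equal to the global training step, consistent with the `model.tfl.ckpt-500` name above.
    `snapshot_steps` and `dense_param_shapes` are hypothetical helpers.

```python
# Bookkeeping for the official example (plain Python, no tflearn).
# Assumptions: batch size 64, 55,000 training images, and checkpoint
# files named "<checkpoint_path>-<global step>".
import math


def snapshot_steps(n_samples, batch_size, n_epoch, snapshot_step=None, snapshot_epoch=True):
    """Global steps at which a checkpoint would be written."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    total = steps_per_epoch * n_epoch
    steps = set()
    if snapshot_step:
        steps.update(range(snapshot_step, total + 1, snapshot_step))
    if snapshot_epoch:
        steps.update(steps_per_epoch * e for e in range(1, n_epoch + 1))
    return sorted(steps)


def dense_param_shapes(layer_sizes):
    """(W, b) shapes for a stack of fully connected layers."""
    return [((n_in, n_out), (n_out,)) for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]


steps = snapshot_steps(55000, 64, n_epoch=1, snapshot_step=500)
print([f"model.tfl.ckpt-{s}" for s in steps])
# ['model.tfl.ckpt-500', 'model.tfl.ckpt-860']

shapes = dense_param_shapes([784, 128, 256, 10])  # dense1, dense2, softmax layer
for name, (w, b) in zip(["dense1", "dense2", "softmax"], shapes):
    print(name, "W", w, "b", b)
print("total parameters:", sum(math.prod(w) + math.prod(b) for w, b in shapes))
# total parameters: 136074
```

    Under these assumptions one epoch is 860 steps, so the first `fit` call would leave checkpoints at steps 500 and 860. For example, `dense1` would return a 784x128 weight matrix and a 128-element bias vector.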
  • Original post: https://www.cnblogs.com/bonelee/p/8582381.html