  • tflearn: saving the model at the end of every epoch

    Key code:
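    # When building the model: checkpoint file prefix, keep at most 10 checkpoints on disk
    model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.)
    # When training: pass this argument to model.fit()
    snapshot_epoch=True,  # Snapshot (save & evaluate) model every epoch.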
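    The snippet below is a small pure-Python simulation of how `checkpoint_path`, `max_checkpoints` and `snapshot_epoch` work together. It is not tflearn code. It assumes that, as with TensorFlow's `tf.train.Saver`, each snapshot is saved under the prefix `<checkpoint_path>-<global_step>`. It also assumes that only the newest `max_checkpoints` snapshots are kept. The sample count, batch size and epoch count are hypothetical.

```python
from collections import deque
import math


def simulate_checkpoints(n_samples, batch_size, n_epochs,
                         checkpoint_path="model_resnet_cifar10",
                         max_checkpoints=10):
    """Simulate which per-epoch checkpoints remain on disk.

    Assumptions (not taken from tflearn's source): one snapshot per epoch,
    named '<checkpoint_path>-<global_step>'; global_step counts mini-batches,
    including the last partial batch of each epoch; the oldest snapshot is
    deleted once more than max_checkpoints exist.
    """
    steps_per_epoch = math.ceil(n_samples / batch_size)
    kept = deque(maxlen=max_checkpoints)  # appending past maxlen drops the oldest
    for epoch in range(1, n_epochs + 1):
        global_step = epoch * steps_per_epoch
        kept.append(f"{checkpoint_path}-{global_step}")
    return list(kept)


if __name__ == "__main__":
    # Hypothetical run: 50,000 training images, batch_size=1024, 15 epochs
    for name in simulate_checkpoints(50_000, 1024, 15):
        print(name)
```

    With these hypothetical numbers, one epoch is ceil(50000 / 1024) = 49 steps. After 15 epochs, only the snapshots for steps 294 through 735 remain. Under the same assumption, the number in a name such as `model_resnet_cifar10-195804` (loaded in main() below) is a training step count, not an epoch number.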
    My demo:
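    import tensorflow as tf
    import tflearn
    from tflearn.layers.core import input_data
    from tflearn.data_utils import image_preloader

    # Input image size. Hypothetical values: the original post uses the globals
    # width/height without defining them; set them to match your images.
    width, height = 64, 64


    # EarlyStoppingCallback is used in main() but not shown in the original post;
    # this is an assumed minimal definition. Raising StopIteration at the end of
    # an epoch makes model.fit() stop training early.
    class EarlyStoppingCallback(tflearn.callbacks.Callback):
        def __init__(self, val_acc_thresh):
            self.val_acc_thresh = val_acc_thresh

        def on_epoch_end(self, training_state):
            # val_acc is None when no validation set is given
            if training_state.val_acc is not None and \
                    training_state.val_acc >= self.val_acc_thresh:
                raise StopIteration


    def get_model(width, height, classes=40):
        # TODO, modify model
        network = input_data(shape=[None, width, height, 3])  # RGB input, e.g. [None, 224, 224, 3]
        # Residual blocks
        # 32 layers: n=5, 56 layers: n=9, 110 layers: n=18 (depth = 6n + 2, so n=2 gives 14 layers)
        n = 2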
        net = tflearn.conv_2d(network, 16, 3, regularizer='L2', weight_decay=0.0001)
        net = tflearn.residual_block(net, n, 16)
        net = tflearn.residual_block(net, 1, 32, downsample=True)
        net = tflearn.residual_block(net, n-1, 32)
        net = tflearn.residual_block(net, 1, 64, downsample=True)
        net = tflearn.residual_block(net, n-1, 64)
        net = tflearn.batch_normalization(net)
        net = tflearn.activation(net, 'relu')
        net = tflearn.global_avg_pool(net)
        # Regression  
        net = tflearn.fully_connected(net, classes, activation='softmax')
        #mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=32000, staircase=True)
        mom = tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)
        net = tflearn.regression(net, optimizer=mom,
                                 loss='categorical_crossentropy')
        # Training  
        model = tflearn.DNN(net, checkpoint_path='model_resnet_cifar10',
                            max_checkpoints=10, tensorboard_verbose=0,
                            clip_gradients=0.)
        return model
    
    
    
    def main():
        trainX, trainY = image_preloader("data/train", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
        testX, testY = image_preloader("data/test", image_shape=(width, height, 3), mode='folder', categorical_labels=True, normalize=True)
        #trainX = trainX.reshape([-1, width, height, 1])
        #testX = testX.reshape([-1, width, height, 1])
        print("sample data:")
        print(trainX[0])
        print(trainY[0])
        print(testX[-1])
        print(testY[-1])
    
        model = get_model(width, height, classes=3755)
    
        filename = 'tflearn_resnet/model.tflearn'
        # try to load model and resume training
        try:
            #model.load(filename)
            model.load("model_resnet_cifar10-195804")
            print("Model loaded OK. Resume training!")
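        except Exception as e:
            # No usable checkpoint: report it instead of failing silently
            print("No checkpoint loaded, training from scratch:", e)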
    
        # Custom callback: stops training once validation accuracy reaches 0.94
        early_stopping_cb = EarlyStoppingCallback(val_acc_thresh=0.94)
        try:
            model.fit(trainX, trainY, validation_set=(testX, testY), n_epoch=500, shuffle=True,
                      snapshot_epoch=True, # Snapshot (save & evaluate) model every epoch.
                      show_metric=True, batch_size=1024, callbacks=early_stopping_cb, run_id='cnn_handwrite')
        except StopIteration:
            # Raised by the callback to end training early
            print("Validation accuracy threshold reached; training stopped early.")
    
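        # Save the complete model, training ops included
        model.save(filename)

        # Clear the TRAIN_OPS collection, then save again: this copy is meant
        # for inference only and can be loaded without the training ops
        del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
        filename = 'tflearn_resnet/model-infer.tflearn'
        model.save(filename)


    if __name__ == '__main__':
        main()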
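    The comment `32 layers: n=5, 56 layers: n=9, 110 layers: n=18` follows a simple depth formula. The network has three stages of n residual blocks, and each block has two conv layers. Adding the first conv_2d and the final fully_connected layer gives depth = 6n + 2. The check below assumes that tflearn's (non-bottleneck) `residual_block` holds two convolutions per block. The helper name `resnet_depth` is hypothetical.

```python
def resnet_depth(n):
    """Count weight layers in the ResNet built by get_model().

    Stage 1 has n residual blocks; stages 2 and 3 each have one downsampling
    block plus n - 1 regular blocks. Each block is assumed to hold two conv
    layers; the initial conv_2d and the final fully_connected add two more.
    """
    blocks = n + (1 + (n - 1)) + (1 + (n - 1))
    return 1 + 2 * blocks + 1


if __name__ == "__main__":
    for n in (2, 5, 9, 18):
        print(f"n={n}: {resnet_depth(n)} layers")
```

    The demo's setting n = 2 therefore gives a 14-layer network, much shallower than the 32-layer configuration in the comment.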
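    As I understand it, `tflearn.Momentum(0.01, lr_decay=0.1, decay_step=2000, staircase=True)` passes these values to TensorFlow's exponential decay. That gives learning_rate * lr_decay ** (step / decay_step), with the exponent rounded down when `staircase=True`. The sketch below reproduces that formula in plain Python so the schedule can be inspected. The function name `momentum_lr` is hypothetical.

```python
def momentum_lr(step, learning_rate=0.01, lr_decay=0.1, decay_step=2000,
                staircase=True):
    """Learning rate after `step` training steps, assuming exponential decay.

    With staircase=True the exponent is an integer, so the rate drops in
    discrete jumps every decay_step steps instead of decaying smoothly.
    """
    exponent = step // decay_step if staircase else step / decay_step
    return learning_rate * lr_decay ** exponent


if __name__ == "__main__":
    for step in (0, 1999, 2000, 4000, 10000):
        print(step, momentum_lr(step))
```

    With the demo's values, the rate falls tenfold every 2,000 mini-batches. At step 10,000 it is 0.01 × 0.1^5 = 1e-7, so on a long run the later steps barely change the weights. The commented-out setting (0.1 with decay_step=32000) decays far more slowly.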
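    The demo resumes training from a hard-coded checkpoint name, `model_resnet_cifar10-195804`. A small helper could pick the newest snapshot instead. The sketch below assumes TensorFlow's checkpoint layout, where every snapshot writes a `<prefix>-<step>.index` file. The helper `latest_checkpoint` is hypothetical; in practice you would pass it `os.listdir(".")` and hand the result to `model.load()`. TensorFlow 1.x also provides `tf.train.latest_checkpoint(checkpoint_dir)`, which reads the `checkpoint` index file instead.

```python
import re


def latest_checkpoint(filenames, prefix="model_resnet_cifar10"):
    """Return '<prefix>-<step>' for the highest step found, or None.

    Assumes each snapshot leaves a '<prefix>-<step>.index' file next to its
    data and meta files.
    """
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.index")
    steps = []
    for name in filenames:
        match = pattern.fullmatch(name)
        if match:
            # Compare step numbers as integers: as strings, "9000" > "195804"
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return f"{prefix}-{max(steps)}"


if __name__ == "__main__":
    files = ["checkpoint",
             "model_resnet_cifar10-9000.index",
             "model_resnet_cifar10-195804.index",
             "model_resnet_cifar10-195804.meta",
             "model_resnet_cifar10-195804.data-00000-of-00001"]
    print(latest_checkpoint(files))
```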
    
  • Original post: https://www.cnblogs.com/bonelee/p/9006243.html