  • TensorFlow-Slim: Introduction + Demo

    Overview on GitHub: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim

    A YOLO-v3 implementation built on slim (tested and working): https://github.com/mystic123/tensorflow-yolo-v3

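    Introduction

    • TF-Slim is a lightweight library for TensorFlow.
    • It makes defining, training, and evaluating complex models simpler.
    • Its components can be mixed freely with other TensorFlow libraries, such as tf.contrib.learn.
    • It lets users define models more compactly by eliminating boilerplate code.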

    Demo

    # Requires TensorFlow 1.x: tf.contrib (and with it slim) was removed in TensorFlow 2.0.
    import tensorflow as tf
    from tensorflow.contrib.layers.python.layers import layers as layers_lib
    import tensorflow.contrib.slim as slim
    from keras.datasets import mnist
    import numpy as np
    
    print("Hello slim.")
    pixel_depth = 256
    learning_rate = 0.01
    checkpoint_dir = "./ckpts/"
    log_dir = "./logs/"
    batch_size = 1000
    
    # Load MNIST; keras downloads it on first use and caches it at ~/.keras/datasets/mnist.npz.
    print("Loading the MNIST data (cached at ~/.keras/datasets/mnist.npz)")
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
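    # Add a channel axis, (N, 28, 28) -> (N, 28, 28, 1), and cast to the dtypes TF expects.
    train_data   = train_data.reshape(-1, 28, 28, 1).astype(np.float32)
    train_labels = train_labels.reshape(-1).astype(np.int64)
    test_data    = test_data.reshape(-1, 28, 28, 1).astype(np.float32)
    test_labels  = test_labels.reshape(-1).astype(np.int64)

    # Rescale pixel values from [0, 255] to roughly [-1, 1).
    train_data = 2.0 * train_data / pixel_depth - 1.0
    test_data  = 2.0 * test_data / pixel_depth - 1.0

    # Use only the first 10,000 training examples.
    train_data   = train_data[0:10000]
    train_labels = train_labels[0:10000]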
    
    print("train data shape:", train_data.shape)
    print("test  data shape:", test_data.shape)
    
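    # Signature modeled on slim.nets.vgg.vgg_16; is_training, dropout_keep_prob and
    # spatial_squeeze are kept for compatibility but are unused in this small model.
    def MyModel(inputs, num_classes=10, is_training=True, dropout_keep_prob=0.5, spatial_squeeze=False, scope='MyModel'):
        with tf.variable_scope(scope):
            # Every conv2d / fully_connected call in this block defaults to ReLU,
            # truncated-normal weights and an L2 weight regularizer.
            with slim.arg_scope([slim.conv2d, slim.fully_connected],
                                activation_fn=tf.nn.relu,
                                weights_initializer=tf.truncated_normal_initializer(0.0, 0.01),
                                weights_regularizer=slim.l2_regularizer(0.0005)):
                # 28x28x1 -> 3x3 conv (SAME) -> 28x28x8 -> 2x2 max pool -> 14x14x8
                net = slim.convolution2d(inputs, 8, [3, 3], 1, padding='SAME', scope='conv1')
                net = layers_lib.max_pool2d(net, [2, 2], scope='pool1')
                # 14x14x8 -> 5x5 conv (SAME) -> 14x14x8 -> 2x2 max pool -> 7x7x8
                net = slim.convolution2d(net, 8, [5, 5], 1, padding='SAME', scope='conv2')
                net = layers_lib.max_pool2d(net, [2, 2], scope='pool2')
                # 7x7x8 -> 392-dimensional vector
                net = slim.flatten(net, scope='flatten1')
                # fc1 keeps the ReLU from arg_scope; with activation_fn=None here,
                # fc1 and fc2 would collapse into a single linear layer.
                net = slim.fully_connected(net, num_classes*num_classes, scope='fc1')
                # fc2 outputs raw logits; the softmax is applied inside the loss.
                net = slim.fully_connected(net, num_classes, activation_fn=None, scope='fc2')
        return net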
    
    # Input pipeline: shuffle the training set, repeat it indefinitely and serve
    # mini-batches of batch_size images with their labels. A one-shot iterator needs
    # no explicit initialization, so it works with slim.learning.train as-is.
    dataset = (tf.data.Dataset.from_tensor_slices((train_data, train_labels))
               .shuffle(buffer_size=10000)
               .repeat()
               .batch(batch_size))
    images, labels = dataset.make_one_shot_iterator().get_next()

    logits = MyModel(images)
    # Adds the cross-entropy loss to the tf.GraphKeys.LOSSES collection.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    # Sum of all collected losses. The L2 weight penalties from arg_scope are excluded
    # here; pass add_regularization_losses=True to include them.
    total_loss = tf.losses.get_total_loss(add_regularization_losses=False)
    
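    optimizer = tf.train.GradientDescentOptimizer(learning_rate)

    # create_train_op computes and applies the gradients; running it returns the loss.
    train_op = slim.learning.create_train_op(total_loss, optimizer)
    # Run 100 training steps, writing summaries every 5 s and checkpoints every 10 s
    # to checkpoint_dir, which slim uses as its log directory.
    slim.learning.train(train_op,
                        checkpoint_dir,
                        number_of_steps=100,
                        save_summaries_secs=5,
                        save_interval_secs=10)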
    
    
    print("See you, slim.")
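    As a quick sanity check on MyModel's layer sizes, the plain-Python arithmetic below traces one 28x28x1 image through the network and counts trainable parameters (weights plus biases). It assumes, as in the demo, SAME-padded stride-1 convolutions (spatial size unchanged) and 2x2 max pools with the contrib default of stride 2 and VALID padding; the helper functions are illustrative and not part of slim.

```python
def pool_out(size, kernel=2, stride=2):
    """Spatial output size of a VALID-padded pooling window."""
    return (size - kernel) // stride + 1


def conv_params(kh, kw, c_in, c_out):
    """Convolution weights plus one bias per output channel."""
    return kh * kw * c_in * c_out + c_out


def fc_params(n_in, n_out):
    """Dense-layer weights plus one bias per output unit."""
    return n_in * n_out + n_out


num_classes = 10
side, channels = 28, 1  # one MNIST image: 28x28 pixels, 1 channel
params = {}

# conv1: 3x3, 8 filters; SAME padding and stride 1 keep the side at 28
params["conv1"] = conv_params(3, 3, channels, 8)
channels = 8
side = pool_out(side)  # pool1: 28 -> 14

# conv2: 5x5, 8 filters; the side stays 14
params["conv2"] = conv_params(5, 5, channels, 8)
side = pool_out(side)  # pool2: 14 -> 7

flat = side * side * channels       # flatten1: 7 * 7 * 8
hidden = num_classes * num_classes  # fc1 width used in the demo
params["fc1"] = fc_params(flat, hidden)
params["fc2"] = fc_params(hidden, num_classes)

total = sum(params.values())
print(side, flat, params, total)
# prints: 7 392 {'conv1': 80, 'conv2': 1608, 'fc1': 39300, 'fc2': 1010} 41998
```

    Almost all of the roughly 42k parameters sit in fc1, the dense layer right after the flatten.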
  • Original post: https://www.cnblogs.com/xbit/p/10059745.html