  • Building a neural network with TensorFlow to solve a classification problem

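    To do a job well you first need to sharpen your tools, so let's start with setting up the environment.

    There are countless ways to install TensorFlow, but I recommend the combination below: it is simple and convenient, and it avoids a lot of baffling problems.
    Anaconda + Jupyter + tensorflow

    The detailed installation steps are shown in the video linked below:
    https://www.youtube.com/watch?v=G2GqLWOERjQ (YouTube may require a proxy, depending on where you are)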
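
    Once the installation is done, it can help to confirm that the packages used in this post are visible to Python. The snippet below is only a minimal sketch: the helper name check_packages and the list of package names are my own assumptions based on the imports used later in the post, not part of the installation video.

```python
import importlib.util

# Top-level import names used by the code in this post
# (skimage is the import name of the scikit-image package).
REQUIRED = ["numpy", "skimage", "matplotlib", "tensorflow"]


def check_packages(names):
    """Return a dict mapping each top-level package name to True if Python can find it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


status = check_packages(REQUIRED)
for name, found in status.items():
    print(name, "OK" if found else "MISSING")
```

    Run it in a Jupyter cell of the environment you just created; anything reported as MISSING still needs to be installed in that environment.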

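    Dataset

    The dataset consists of Belgian traffic signs. It can be downloaded from https://btsd.ethz.ch/shareddata/ , where BelgiumTSC_Training (171.3 MB) and BelgiumTSC_Testing (76.5 MB) are our training data and test data respectively.

    About the dataset

    The Training folder contains 62 subfolders, each holding a number of images. The images are our features, and the name of the folder an image sits in is its label.
    Our training goal is: given an image, decide which folder it belongs to (a classification problem).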
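
    To make the "folder name is the label" layout concrete, here is a small sketch that counts the image files per label. It builds a throwaway directory with empty files instead of touching the real dataset; the folder names 00000 and 00001, the file names, and the helper count_images_per_label are hypothetical and only serve the illustration.

```python
import os
import tempfile


def count_images_per_label(data_dir, extension=".ppm"):
    """Map each integer label (the subfolder name) to the number of image files in it."""
    counts = {}
    for name in sorted(os.listdir(data_dir)):
        label_dir = os.path.join(data_dir, name)
        if os.path.isdir(label_dir):
            files = [f for f in os.listdir(label_dir) if f.endswith(extension)]
            counts[int(name)] = len(files)
    return counts


# Build a tiny stand-in for the Training folder (hypothetical names, empty files).
with tempfile.TemporaryDirectory() as root:
    for label, n_images in [("00000", 2), ("00001", 3)]:
        os.makedirs(os.path.join(root, label))
        for i in range(n_images):
            open(os.path.join(root, label, "img_{}.ppm".format(i)), "w").close()
    demo_counts = count_images_per_label(root)

print(demo_counts)
```

    Pointing the same function at the real Training folder should report one entry per class, which is a quick way to see how many images each label has.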

    Now for the good stuff: the code!

    -Load the data and build the features and labels of the training set

    import os
    import numpy as np
    from skimage import io

    def load_data(data_dir):
        # Get all subdirectories of data_dir. Each represents a label.
        directories = [d for d in os.listdir(data_dir)
                       if os.path.isdir(os.path.join(data_dir, d))]
        # Loop through the label directories and collect the data in
        # two lists, labels and images.
        labels = []
        images = []
        for d in directories:
            label_dir = os.path.join(data_dir, d)
            file_names = [os.path.join(label_dir, f)
                          for f in os.listdir(label_dir)
                          if f.endswith(".ppm")]
            for f in file_names:
                # skimage.io.imread is used here because skimage.data.imread
                # is deprecated and missing from newer scikit-image releases
                images.append(io.imread(f))
                # The directory name is the numeric class label
                labels.append(int(d))
        return images, labels
    
    ROOT_PATH = "E:/machineLearning/tensorflow/data/"  # Change this to the path where you stored the data
    train_data_dir = os.path.join(ROOT_PATH, "BelgiumTSC_Training/Training")
    test_data_dir = os.path.join(ROOT_PATH, "BelgiumTSC_Testing/Testing")
    
    images, labels = load_data(train_data_dir)
    
    images_array = np.array(images)
    labels_array = np.array(labels)
    
    # Print the `images` dimensions
    print(images_array.ndim)
    
    # Print the number of `images`'s elements
    print(images_array.size)
    
    # Print the first instance of `images`
    # print(images_array[0])
    
    # Print the `labels` dimensions
    print(labels_array.ndim)
    
    # Print the number of `labels`'s elements
    print(labels_array.size)
    
    # Count the number of labels
    print(len(set(labels_array)))
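
    Besides the number of distinct labels, it is often useful to look at how many images each label has, since traffic-sign classes are rarely balanced. The sketch below uses a toy label list in place of the real labels list; class_distribution is a hypothetical helper name.

```python
from collections import Counter


def class_distribution(labels):
    """Return (label, count) pairs sorted from most to least frequent."""
    return Counter(labels).most_common()


# Toy labels standing in for the real `labels` list returned by load_data.
toy_labels = [0, 0, 0, 1, 1, 2]
distribution = class_distribution(toy_labels)
print(distribution)
```

    Calling class_distribution(labels) on the real data shows which classes dominate the training set and which ones have only a handful of examples.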
    

    -Feature extraction
    Resize the images:

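    from skimage import transform

    # Resize every image to 28x28 pixels
    # (the variable is called images32 even though the size is 28x28)
    images32 = [transform.resize(image, (28, 28)) for image in images]
    images32 = np.array(images32)
    print(images32[0])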
    

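    Convert the color images to grayscale

    import matplotlib.pyplot as plt
    from skimage.color import rgb2gray

    # Collapse the RGB channels so the array shape becomes (N, 28, 28)
    images32 = rgb2gray(images32)

    # Indices of four sample images to display (an arbitrary choice; any valid indices work)
    traffic_signs = [300, 2250, 3650, 4000]

    for i in range(len(traffic_signs)):
        plt.subplot(1, 4, i+1)
        plt.axis('off')
        plt.imshow(images32[traffic_signs[i]], cmap="gray")
        plt.subplots_adjust(wspace=0.5)

    plt.show()

    print(images32.shape)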
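
    For readers wondering what the grayscale step does to each pixel: skimage's rgb2gray (the function this post applies to the test images) computes a weighted sum of the red, green and blue channels. The pure-Python sketch below reproduces that weighting on a tiny hand-made image; the helper names and the 2x2 image are my own illustration, and the real function works on whole NumPy arrays at once.

```python
# Luminance weights documented for skimage.color.rgb2gray (red, green, blue).
WEIGHTS = (0.2125, 0.7154, 0.0721)


def pixel_to_gray(rgb):
    """Weighted sum of one (r, g, b) pixel with channel values in [0, 1]."""
    r, g, b = rgb
    return WEIGHTS[0] * r + WEIGHTS[1] * g + WEIGHTS[2] * b


def image_to_gray(image):
    """Convert a nested list of (r, g, b) pixels into a nested list of gray values."""
    return [[pixel_to_gray(pixel) for pixel in row] for row in image]


# A hypothetical 2x2 image: white, black, pure red, pure green.
tiny = [[(1.0, 1.0, 1.0), (0.0, 0.0, 0.0)],
        [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]]
gray = image_to_gray(tiny)
print(gray)
```

    The result has one value per pixel instead of three, which is why a grayscale batch fits a placeholder of shape [None, 28, 28].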
    

    -Train a neural network with TensorFlow

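    # Note: tf.placeholder, tf.contrib and tf.Session are TensorFlow 1.x APIs
    import tensorflow as tf
    x = tf.placeholder(dtype = tf.float32, shape = [None, 28, 28])
    y = tf.placeholder(dtype = tf.int32, shape = [None])
    # Flatten each 28x28 image into a 784-element vector
    images_flat = tf.contrib.layers.flatten(x)
    # A single fully connected layer with one output per class (62 classes)
    logits = tf.contrib.layers.fully_connected(images_flat, 62, tf.nn.relu)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels = y, logits = logits))
    train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
    # Predicted class index for each image (tf.argmax returns int64)
    correct_pred = tf.argmax(logits, 1)
    # Fraction of predictions that equal the true label
    accuracy = tf.reduce_mean(tf.cast(tf.equal(correct_pred, tf.cast(y, tf.int64)), tf.float32))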
    
    print("images_flat: ", images_flat)
    print("logits: ", logits) 
    print("loss: ", loss)
    print("predicted_labels: ", correct_pred)
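
    The loss above is the mean sparse softmax cross-entropy. As a rough illustration of what that quantity is, here is a pure-Python sketch for a handful of examples; the function names are mine, and TensorFlow's own implementation is vectorized and differs in its numerical details.

```python
import math


def softmax(logits):
    """Turn a list of raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_softmax_cross_entropy(logits, label):
    """Negative log-probability assigned to the true class index `label`."""
    return -math.log(softmax(logits)[label])


def mean_loss(batch_logits, labels):
    """Average the per-example losses, as tf.reduce_mean does in the graph."""
    losses = [sparse_softmax_cross_entropy(l, y) for l, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# With 62 equal scores every class gets probability 1/62.
uniform = [0.0] * 62
loss_uniform = mean_loss([uniform, uniform], [3, 40])
print(loss_uniform)
```

    An untrained network that spreads its probability evenly over the 62 classes therefore starts with a loss of ln(62), about 4.13, which is a handy reference point when reading the training log.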
    
    

    Train the neural network

    sess = tf.Session()

    sess.run(tf.global_variables_initializer())

    for i in range(201):
        print('EPOCH', i)
        # Fetch the loss value too, so a number is printed rather than the tensor object
        _, loss_val, accuracy_val = sess.run([train_op, loss, accuracy], feed_dict={x: images32, y: labels})
        if i % 10 == 0:
            print("Loss: ", loss_val)
        print('DONE WITH EPOCH')
    

    Run the neural network

    import random

    # Pick 10 random images
    sample_indexes = random.sample(range(len(images32)), 10)
    sample_images = [images32[i] for i in sample_indexes]
    sample_labels = [labels[i] for i in sample_indexes]
    
    # Run the "predicted_labels" op.
    predicted = sess.run([correct_pred], feed_dict={x: sample_images})[0]
                            
    # Print the real and predicted labels
    print(sample_labels)
    print(predicted)
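
    Because random.sample draws different images on every run, the printed lists change each time. If you want to compare runs on the same ten images, one option is a seeded generator, sketched below; pick_sample_indexes and the seed value 42 are hypothetical choices of mine, not part of the original code.

```python
import random


def pick_sample_indexes(n_items, k, seed):
    """Draw k distinct indexes from range(n_items) using a private, seeded generator."""
    rng = random.Random(seed)
    return rng.sample(range(n_items), k)


# The same seed gives the same indexes on every call.
first = pick_sample_indexes(100, 10, seed=42)
second = pick_sample_indexes(100, 10, seed=42)
print(first == second)
```

    With the real data you would pass len(images32) as n_items and use the returned indexes exactly like sample_indexes above.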
    

    -Display the predictions

    # Display the predictions and the ground truth visually.
    fig = plt.figure(figsize=(10, 10))
    for i in range(len(sample_images)):
        truth = sample_labels[i]
        prediction = predicted[i]
        plt.subplot(5, 2,1+i)
        plt.axis('off')
        color='green' if truth == prediction else 'red'
        plt.text(40, 10, "Truth:        {0}\nPrediction: {1}".format(truth, prediction),
                 fontsize=12, color=color)
        plt.imshow(sample_images[i], cmap="gray")
    
    plt.show()
    

    -Predict on the test set

    # Load the test data
    test_images, test_labels = load_data(test_data_dir)
    
    # Transform the images to 28 by 28 pixels
    test_images28 = [transform.resize(image, (28, 28)) for image in test_images]
    
    # Convert to grayscale
    from skimage.color import rgb2gray
    test_images28 = rgb2gray(np.array(test_images28))
    
    # Run predictions against the full test set.
    predicted = sess.run([correct_pred], feed_dict={x: test_images28})[0]
    
    # Calculate correct matches 
    match_count = sum([int(y == y_) for y, y_ in zip(test_labels, predicted)])
    
    # Calculate the accuracy (a separate name avoids overwriting the `accuracy` tensor defined earlier)
    test_accuracy = match_count / len(test_labels)

    # Print the accuracy
    print("Accuracy: {:.3f}".format(test_accuracy))
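
    A single accuracy number hides which signs the network gets wrong. As a possible follow-up, the sketch below breaks the accuracy down per class; it runs on toy lists standing in for test_labels and predicted, and per_class_accuracy is a hypothetical helper name.

```python
from collections import defaultdict


def per_class_accuracy(true_labels, predicted_labels):
    """Return {label: fraction of that label's examples predicted correctly}."""
    totals = defaultdict(int)
    correct = defaultdict(int)
    for truth, prediction in zip(true_labels, predicted_labels):
        totals[truth] += 1
        if truth == prediction:
            correct[truth] += 1
    return {label: correct[label] / totals[label] for label in totals}


# Toy data standing in for `test_labels` and `predicted`.
toy_true = [0, 0, 1, 1, 1, 2]
toy_pred = [0, 1, 1, 1, 0, 2]
result = per_class_accuracy(toy_true, toy_pred)
print(result)
```

    Classes with a low value here are the ones to inspect first, for example by displaying a few of their misclassified images.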
    

    -Close the session

    sess.close()
    

    The prediction accuracy is roughly 57.8%.

  • Original post: https://www.cnblogs.com/share-sjb/p/10012163.html