zoukankan      html  css  js  c++  java
  • Tensorflow最简单实现ResNet50残差神经网络,进行图像分类,速度超快

    在图像分类领域内,其中的大杀器莫过于Resnet50了,这个残差神经网络当时被发明出来之后,顿时毁天灭敌,其余任何模型都无法想与之比拟。我们下面用Tensorflow来调用这个模型,让我们的神经网络对Fashion-mnist数据集进行图像分类.由于在这个数据集当中图像的尺寸是28*28*1的,如果想要使用resnet那就需要把28*28*1的灰度图变为224*224*3的RGB图,我们使用OpenCV库可以很容易将图像进行resize。

    首先我们进行导包:

    import os,sys
    import numpy as np
    import scipy
    from scipy import ndimage
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from tensorflow.keras.applications.resnet50 import ResNet50
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
    from PIL import Image
    import random

    加载数据集:

    (train_image,train_label),(test_image,test_label)=tf.keras.datasets.fashion_mnist.load_data()

    导入opencv并重命名:

    import cv2 as cv

    读取数据集当中的500张图片(注意不要使用所有的图片进行读取和resize,不然电脑的内存将会不存,因为resize之后每一张图片的尺寸大大增加,60000张图片所需要的电脑内存大致需要8.1Gb,使用CPU进行训练的话,你的内存条也就需要目前空余的至少在8gb以上,后期加上resnet的权重参数那更是几个亿,电脑的运行内存是不可能这么大的,毕竟只要我们的神经网络好,几个epoch就可以得到很好的验证集准确度了,没有必要追求数量),读取和同时进行resize为224*224*3的代码如下:

    train_data = []
    for img in train_image[:500]:
        resized_img = cv.resize(img, (224, 224))
        resized_img = cv.cvtColor(resized_img, cv.COLOR_GRAY2BGR)
        train_data.append(resized_img)

    我们最后得到了一个三维的列表数据,但是这并不是一个ndarray,也就是numpy当中的数组对象,还无法进行训练,我们需要将其转化为numpy当中的数组,代码如下:

    train_data=np.array(train_data)
    train_data.shape

    输出目前的shape为:

    (500, 224, 224, 3)

    将数据进行归一化,加速卷积神经网络的运算:

    train_data=train_data/255

    导入Resnet50模型,同时编译模型:

    model = ResNet50(
        weights=None,
        classes=10
    )
    
    model.compile(optimizer="Adam",
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    开始拟合模型:

    model.fit(train_data,train_label[0:500], epochs=10, batch_size=6)

    输出:

    Train on 500 samples
    Epoch 1/10
    500/500 [==============================] - 256s 511ms/sample - loss: 1.5721 - accuracy: 0.4260
    Epoch 2/10
    500/500 [==============================] - 255s 511ms/sample - loss: 1.3282 - accuracy: 0.5600
    Epoch 3/10
    500/500 [==============================] - 260s 519ms/sample - loss: 1.1301 - accuracy: 0.6180
    Epoch 4/10
    500/500 [==============================] - 259s 519ms/sample - loss: 1.1403 - accuracy: 0.6080
    Epoch 5/10
    500/500 [==============================] - 261s 521ms/sample - loss: 1.0098 - accuracy: 0.6400
    Epoch 6/10
    500/500 [==============================] - 264s 528ms/sample - loss: 0.9646 - accuracy: 0.6860
    Epoch 7/10
    500/500 [==============================] - 268s 535ms/sample - loss: 0.8954 - accuracy: 0.6940
    Epoch 8/10
    500/500 [==============================] - 269s 539ms/sample - loss: 0.7415 - accuracy: 0.7540
    Epoch 9/10
    500/500 [==============================] - 274s 549ms/sample - loss: 0.7001 - accuracy: 0.7880
    Epoch 10/10
    500/500 [=============================] - 275s 551ms/sample - loss: 0.5996 - accuracy: 0.8020

    从中可以发现只需要500张图片,进行十次epoch,训练集的准确度已经达到百分之八十。

    这样我们就使用tensorflow2.0快速实现了一个Resnet50的神经网络了!

  • 相关阅读:
    java编程基础--方法
    MySQL中使用LIMIT进行分页的方法
    Java编程基础--数据类型
    Java开发入门
    SpringBoot实战项目(十七)--使用拦截器实现系统日志功能
    SpringBoot实战项目(十六)--拦截器配置及登录拦截
    SpringBoot实战项目(十五)--修改密码及登录退出功能实现
    SpringBoot实战项目(十四)--登录功能之登录表单验证
    PHP setcookie 网络函数
    PHP mysqli_kill MySQLi 函数
  • 原文地址:https://www.cnblogs.com/geeksongs/p/13215624.html
Copyright © 2011-2022 走看看