zoukankan      html  css  js  c++  java
  • 【小白学PyTorch】15 TF2实现一个简单的服装分类任务

    【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、时间序列等多个目标为技术学习的分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.

    参考目录:

    0 为什么学TF

    之前的15节课的pytorch的学习,应该是让不少朋友对PyTorch有了一个全面而深刻的认识了吧 (如果你认真跑代码了并且认真看文章了的话)

    大家都会比较Tensorflow2和pytorch之间孰优孰劣,但是我们也并不是非要二者选一,两者都是深度学习的工具,其实我们或多或少应该了解一些比较好。 就好比,PyTorch是冲锋枪,TensorFlow是步枪,在上战场前,我们可以选择带上冲锋枪还是步枪,但是在战场上,可能手中的枪支没有子弹了,你只能在地上随便捡了一把枪。 很多时候,用Pytorch还是Tensorflow的选择权不在自己。

    此外,了解了TensorFlow,大家才能更好的理解PyTorch和TF究竟有什么区别。我见过有的大佬是TF和PyTorch一起用在一个项目中,数据读取用PyTorch然后模型用TF构建。

    总之,大家有时间有精力的话,顺便学学TF也不亏,更何况TF2.0现在已经优化了很多。本系列预计用3节课来简单的入门一下Tensorflow2.

    和PyTorch的第一课一样,我们直接做一个简单的小实战。MNIST手写数字分类,Fashion MNIST时尚服装分类。

    1 Tensorflow的安装

    安装TensorFlow的方法很简单,就是在控制台执行:

    pip install tensorflow --user
    

    这里的--user是赋予这个命令执行权限的,一般我都会带上。

    2 数据集构建

    # keras是TF的高级API,用起来更加的方便,一般也是用keras。
    import tensorflow as tf
    from tensorflow import keras
    import numpy as np
    

    导入需要用到的库函数. 正如torchvision.datasets中一样,keras.datasets中也封装了一些常用的数据集。

    fashion_mnist = keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    print('train_images shape:',train_images.shape)
    print('train_labels shape:',train_labels.shape)
    print('test_images shape:',test_images.shape)
    print('test_labels shape:',test_labels.shape)
    

    输出结果是:

    训练数据集中有60000个样本,每一个样本和MNIST手写数字大小是一样的,是(28 imes 28)大小的,然后每一个样本有一个标签,这个标签和MNIST也是一样的,是从0到9,是一个十分类任务。

    来看一下这些类别有哪些:

    标签 类别 标签 类别
    0 T-shirt 5 Sandal
    1 Trouser 6 Shirt
    2 Pullover 7 Sneaker
    3 Dress 8 Bag
    4 Coat 9 Ankle boot

    这里学学单词吧:

    • T-shirt就是T型的衬衫,就是短袖,我感觉前面没有扣子的那种也叫T-shirt;
    • Shirt就是长袖的那种衬衫;
    • Trouser是裤子;
    • pullover是毛衣,套头毛衣,就是常说的卫衣吧感觉;
    • dress连衣裙;
    • coat是外套;
    • sandal是凉鞋;
    • sneaker是运动鞋;
    • ankle boot是短靴,是到脚踝的那种靴子;
    • 这里补充一个吧,sweater,是毛线衣,运动衫,这个和pullover有些类似,个人感觉主要的区分在于运动系列的可以叫做sweater,其他的毛衣卫衣是pullover。

    运动短袖T-shirt+运动卫衣sweater是我秋天去健身房的穿搭。

    2 预处理

    这里不做图像增强之类的了,上面的数据中,图像像素值是从0到255的,我们要把这些标准化成0到1的范围。

    train_images = train_images / 255.0
    test_images = test_images / 255.0
    

    3 构建模型

    # 模型搭建
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    

    这就是一个用keras构建简单模型的例子:

    • keras.layers.Flatten是把(28 imes 28)的二维度拉平成一个维度,因为这里是直接用全连接层而不是卷积层进行处理的;
    • 后面跟上两个全连接层keras.layers.Dense()就行了。我们可以发现,这个全连接层的参数和PyTorch是有一些区别的:
      1. PyTorch的全连接层需要一个输入神经元数量和输出数量torch.nn.Linear(5,10),而keras中的Dense是不需要输入参数的keras.layers.Dense(10)
      2. keras中的激活层直接封装在了Dense函数里面,所以不需要像PyTorch一样单独写一个nn.ReLU()了。

    4 优化器

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    

    定义优化器和损失函数,在keras中叫做对模型进行编译compile(在C语言中,在运行代码之前都需要对代码进行编译嘛)。损失函数和优化器还有metric衡量指标的设置都在模型的编译函数中设置完成。

    上面使用Adam作为优化器,然后损失函数用了交叉熵,然后衡量模型性能的使用了准确率Accuracy。

    5 训练与预测

    model.fit(train_images, train_labels, epochs=10)
    

    这就是训练过程,相比PyTorch而言,更加的简单简洁,但是不像PyTorch那样灵活。

    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print('
    Test accuracy:', test_acc)
    

    这个.evaluate方法是对模型的验证集进行验证的,因为本次任务中并没有对训练数据再划分出验证集,所以这里直接使用测试数据了。

    大家应该能理解训练集、验证集和测试集的用途和区别吧,我在第二课讲过这个内容,在此不多加赘述。

    predictions = model.predict(test_images)
    

    这个.predict方法才是用在测试集上,进行未知标签样本的类别推理的。

    本次内容到此为止,大家应该对keras和tensorflow有一个直观浅显的认识了。当然tensorflow也有一套类似于PyTorch中的dataset,dataloader的那样自定义的数据集加载器的方法,在后续内容中会深入浅出的学一下。

  • 相关阅读:
    全面理解javascript的caller,callee,call,apply概念(修改版)
    动态显示更多信息(toggle_visible函数的运用)
    再论call和apply
    RSS News Module的应用
    准备制作一套全新的DNN皮肤(包括个人定制或企业级定制)
    ControlPanel研究系列二:简单Ajax模式的ControlPanel(SimplAjax)
    New_Skin发布了....
    如何定制dnn的FckEditor
    Blog已迁移到dnnsun.com(最新DotNetNuke咨询平台)
    新DNN皮肤的经验及成果分享
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13749213.html
Copyright © 2011-2022 走看看