zoukankan      html  css  js  c++  java
  • Python 实现深度学习(2)

    第一篇:基础知识简介


    第一篇是基础知识简介,对于过于简单的知识点,不会详细叙述,分为两部分:

    1. python基础知识:将后期需要的了解的知识点列出,并给出相关资料。

    2. 神经网络基础知识:感知机是神经网络的前身,对感知机简单的介绍。

    本篇的目的和内容主要为: 介绍感知机和python;


    一、Python基础知识

    本章会列出实现的神经网络所需要的基础知识,并给出参考资料


    TODO:

    介绍numpy库和matplotlib库、读写二进制的方法、pkl等。这些知识会在后面用到,在本篇的最后会以mnist数据集为例,创建处理手写体图片的函数,供后使用。

    1. class 和function

    2. numpy

    3. Matplotlib

    4. 序列化

    image



    1.1 class 和function

    python3 函数

    python3 面向对象


    1.2 Numpy

    Nump(Numerical Python)是Python的运算库,支持大规模的数组和矩阵运算。在深度学习的实现中会使用矩阵进行计算,numpy中实现了很多数据组的运算方法,在后期会用到的有:

    • Nmupy的数据结构ndarray
    • Numpy的切片和索引
    • 广播功能函数
    • 算术函数


    1.2.1 Ndarray

    Numpy中主要的数据结构是Ndarray,用于存放同类型元素的多维数组。

    Ndarray的内部如图1所示。


    image

    fig1. ndarray 的数据结构

    数据类型:dtype,描述数据类型,可以计算每个元素大小;

    数组形状:shape,描述数组的大小和形状;

    跨度元组,stride:表示从前一个维度到下一个维度需要跨越的字节数;

    data: 指向数组的地址;

    ps: 后期会用到dtype, shape等成员变量



    1.2.2 广播

    Numpy对于不同形状的乘法采用了广播机制。

    广播可以对不同形状的数组做点乘:将较小的形状按照一定的规则填充,填充的方向依次为由内向外;广播机制在cudnn、tensorflow等深度学习框架中同样会使用。

    广播是一种ufunc的机制是 不同形状的数组之间执行算数运算的方式,需要遵循4个原则:

    • 1.让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐
    • 2.输入数组的shape是输入数组shape的各个轴上的最大值
    • 3.如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错。
    • 4.输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值。

    举例说明:

    假设,两个矩阵要做乘法,第一个矩阵是2*2, 但第二个矩阵并不是2*2的,按照数学运算法则是不能做点乘的

    但如果有广播机制,会按照以下方式填充数据,并做乘法:

    先从行广播,然后再从列广播,举例如下


    Case1: 行列都不一致。先填充行,再填充列。

    image


    Case2: 行不一致,列一致。先填充行

    image

    image


    Case3: 行一致,列不一致。由于行已经一致了,不需要填充,直接填充列。

    image


    Case4: 行列不一致,且有一个维度无法广播

    image

    更多关于广播机制,详见: basics.broadcasting


    1.2.3 其他知识点

    numpy的切片和索引的有关内容在 fancy-indexing-and-index-tricks 中可以找到。

    至于算术运算等网上的资料已经足够多的了,不需要我再重复操作了,这里给出一个官方的资料:numpy-quickstart.html


    1.3 matplotlib和skimage

    matplotlib和skimage在可视化数据的时候会用到。网上的资料足够多的了,在此不多介绍,给出参考资料:

    https://www.runoob.com/w3cnote/matplotlib-tutorial.html

    https://www.runoob.com/numpy/numpy-matplotlib.html

    scikit-image

    https://cloud.tencent.com/developer/section/1414638



    1.4 python序列化

    Serialization序列化,是将内存中对象以二进制的方式存储起来,存到磁盘。如果将磁盘中的文件解析成一个对象,这个过程称为deSerialization。序列化的数据可以用于网络传输,不会因为编码方式而改变。Python中的序列化由pickle模块实现。以下是参考资料:

    pickle:Python object serialization

    https://docs.python.org/zh-cn/2.7/library/pickle.html






    1.5 实践:mnist数据集解析

    本章会撰写程序实现一下功能:

    1.下载mnsit数据集,解析mnist数据放在numpy的array中;

    2.将解析的数据先序列化,然后持久化

    3.反序列化,读取mnist中的一样图像,用plt或者skimage显示。


      1 # -*- coding: utf-8 -*-
      2 # @File  : mnist.py
      3 # @Author: lizhen
      4 # @Date  : 2020/2/4
      5 # @Desc  : 工具类,datasets/mnist.py
      6 
      7 import urllib.request # python3
      8 import os.path
      9 import gzip
     10 import pickle
     11 import os
     12 import numpy as np
     13 
     14 # http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
     15 # http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
     16 # http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
     17 # http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
     18 
     19 
     20 url_base = "http://yann.lecun.com/exdb/mnist/"
     21 key_file = {
     22         'train_img':'train-images-idx3-ubyte.gz',
     23         'train_label':'train-labels-idx1-ubyte.gz',
     24         'test_img':'t10k-images-idx3-ubyte.gz',
     25         'test_label':'t10k-labels-idx1-ubyte.gz'
     26         }
     27 
     28 dataset_dir=os.path.dirname(os.path.abspath(__file__))
     29 save_file=dataset_dir + "/mnist.pkl"
     30 
     31 train_num = 60000;
     32 test_num  = 10000;
     33 img_dim   = (1, 28, 28)
     34 img_size  = 28*28;
     35 
     36 
     37 def _download(file_name):
     38     """
     39     :param file_name: 下载mnist的文件
     40     :return: null
     41     """
     42     file_path = os.path.join(dataset_dir, file_name)
     43 
     44     if os.path.exists(file_path):
     45         return
     46 
     47     print("downloading"+file_name+ "...")
     48     urllib.request.urlretrieve(url_base + file_name , file_path)
     49     print("Done.")
     50 
     51 def download_mnist():
     52     """
     53 
     54     :return:
     55     """
     56     for file_name in key_file.values():
     57         _download(file_name);
     58 
     59 def _load_label(file_name):
     60     """
     61     解析标签
     62     :param file_name:
     63     :return:
     64     """
     65     file_path = dataset_dir+'/'+ file_name
     66 
     67     print("converting "+file_name+" to numpy Array.")
     68     with gzip.open(file_path) as f:
     69         labels = np.frombuffer(f.read(), np.uint8, offset=8)
     70     print("Done")
     71 
     72     return labels
     73 
     74 def _load_img(file_name):
     75     """
     76     解析 压缩的图片
     77     :param file_name:
     78     :return:
     79     """
     80     file_path = dataset_dir +'/' + file_name
     81 
     82     print("converting "+ file_name + "to numpy Array")
     83     with gzip.open(file_path) as f:
     84         data = np.frombuffer(f.read(), np.uint8, offset=16) # 16*8=
     85     data = data.reshape(-1, img_size) # N, (W*H*C)=[N,28*28*1]
     86     print("Done")
     87 
     88     return data
     89 
     90 def _convert_numpy():
     91     """
     92      解析 image和label,将其转换为numpy
     93     """
     94     dataset = {}
     95     dataset['train_img'] = _load_img(key_file['train_img'])
     96     dataset['train_label'] = _load_label(key_file['train_label'])
     97     dataset['test_img'] = _load_img(key_file['test_img'])
     98     dataset['test_label'] = _load_label(key_file['test_label'])
     99 
    100     return dataset
    101 
    102 def init_mnist():
    103     """
    104     初始化mnist数据集:
    105     1. 下载mnist,
    106     2. 以二进制的方式读取,并转换成numpy的ndarray对象
    107     3. 将转换后的ndarray 序列化
    108 
    109     :return:
    110     """
    111     print("download mnist dataset...")
    112     download_mnist()
    113     print("convert to numpy array...")
    114     dataset = _convert_numpy()
    115     print("creating pickle file ...")
    116     with open(save_file, 'wb') as f:
    117         pickle.dump(dataset, f, -1)
    118     print("Done!")
    119 
    120 def _change_one_hot_label(Y):
    121     T = np.zeros((Y.size,10))
    122     for idx,row in enumerate(T):
    123         row[Y[idx]] = 1
    124     return T
    125 
    126 def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    127     """
    128 
    129     :param normalize: 将数据标准化到0.0~1.0
    130     :param flatten: 是否要将数据拉伸层1D数组的形式
    131     :param one_hot_label:
    132     :return: (训练数据, 训练标签), (测试数据, 测试label)
    133     """
    134 
    135 
    136     if not os.path.exists(save_file):
    137         init_mnist()
    138 
    139     with open(save_file,'rb') as f:
    140         dataset = pickle.load(f)
    141 
    142     if normalize:
    143         for key in ('train_img','test_img'):
    144             dataset[key] = dataset[key].astype(np.float32)
    145             dataset[key] /=255.0
    146     if one_hot_label:
    147         dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
    148         dataset['test_label']  = _change_one_hot_label(dataset['test_label'])
    149 
    150     if not flatten:
    151         for key in ('train_img', 'test_img'):
    152             dataset[key] = dataset[key].reshape(-1,1,28,28) # NCHW
    153 
    154     return (dataset['train_img'],dataset['train_label']),(dataset['test_img'], dataset['test_label'])
    155 
    156 if __name__ == '__main__':
    157     init_mnist()
    158 



    测试


      1 # -*- coding: utf-8 -*-
      2 # @File  : show_mnist.py
      3 # @Author: lizhen
      4 # @Date  : 2020/1/27
      5 # @Desc  : 显示图片
      6 
      7 from src.datasets.mnist import load_mnist
      8 
      9 from skimage import io
     10 
     11 
     12 def img_show(data):
     13     # pil_img = Image.fromarray(np.uint8(data))
     14     io.imshow(data)
     15     io.show()
     16 
     17 
     18 (x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
     19 img = x_train[0]
     20 label = t_train[0]
     21 print(label)
     22 
     23 print(img.shape)
     24 img = img.reshape(28,28)
     25 print(img.shape)
     26 
     27 img_show(img)
     28 


    image


    2020年2月11日 修改

  • 相关阅读:
    不知如何摧毁Kendo UI for jQuery小部件?这份指南不得不看
    MyEclipse导航代码第二弹,Java开发更便捷
    索引扫描与索引查找区别
    Chrome使用技巧
    什么是中台?所有的中台都是业务中台
    跨域资源共享CORS详解
    多线程之入门起步(三)
    聊天程序——基于Socket、Thread (二)
    多线程的相关概念(一)
    使用BCP实用工具导出导入数据
  • 原文地址:https://www.cnblogs.com/greentomlee/p/12314915.html
Copyright © 2011-2022 走看看