zoukankan      html  css  js  c++  java
  • 如何使用 libtorch 实现 LeNet 网络?

    如何使用 libtorch 实现 LeNet 网络?

    LeNet 网络论文地址:
    http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

    LeNet

    C1 卷积层

    {1,1,28,28} 是什么?

    1 输入的批次
    1 图像的通道大小
    28 图像的高
    28 图像的宽

    输入:{1,1,28,28}

    通过填充一个边界 2 ,使得输入变成 {1,1,32,32}

    滑动窗口大小:{5,5}

    输出:{1,6,32,32}

    S2 降采样

    输入:{1,6,32,32}

    滑动窗口大小:{2,2,}
    滑动步长:{2,2}

    输出:{1,6,14,14}

    C3 卷积层

    输入:{1,16,14,14}

    滑动窗口大小:{5,5}

    输出:{1,16,10,10}

    S4 降采样

    输入:{1,16,10,10}

    滑动窗口大小:{2,2,}
    滑动步长:{2,2}

    输出:{1,16,5,5}

    C5 卷积层

    输入:{1,16,5,5}

    滑动窗口大小:{5,5}

    输出:{1,120,1,1}

    F6 全连接层

    这里要把网络形状从 {1,120,1,1} 改变改变成 {1,120}

    第一个全连接
    输入:{1,120}
    输出:{1,84}

    第二个全连接
    输入:{1,84}
    输出:{84,10}

    0~9 总共是 10 个类别嘛,这里就输出 10个就行了。

    全连接就是线性层,网络形状不一样不能全连接的,所以这里要把形状改变成一样的。
    基本按照那图写一遍就明白了。

    关于输入和输出的网络推断公式可以去参考 pytorch 里面的函数说明,上面都有写推断公式滴。

    // Define a new Module.
    struct Net : torch::nn::Module {
    	Net() {
    		conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, /*kernel_size*/{ 5,5 }).padding(/*28->32*/{2,2})));
    		conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, /*kernel_size*/{5,5})));
    		conv3 = register_module("conv3", torch::nn::Conv2d(torch::nn::Conv2dOptions(16, 120, /*kernel_size*/{5,5})));
    		fc1 = register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(120, 84)));
    		fc2 = register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(84, 10)));
    	}
    
    	// Implement the Net's algorithm.
    	torch::Tensor forward(torch::Tensor x) {
    		x = conv1->forward(x);//6@28x28
    		x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//6@14x14
    		x = conv2->forward(x);//16@10x10
    		x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//16@10x10
    		
    		x = conv3->forward(x);//120@1x1
    		x = x.view({ x.size(0),-1 });//-1 表示自动推理计算出该值
    		x = fc1->forward(x);//120->84
    		x = fc2->forward(x);//84->10
    		x = torch::log_softmax(x,/*dim=*/1);
    		return x;
    	}
    
    	// Use one of many "standard library" modules.
    	torch::nn::Conv2d conv1 { nullptr };
    	torch::nn::Conv2d conv2 { nullptr };
    	torch::nn::Conv2d conv3 { nullptr };
    	torch::nn::Linear fc1{ nullptr };
    	torch::nn::Linear fc2{ nullptr };
    };
    
  • 相关阅读:
    视图、触发器、事物、存储过程、函数、流程控制
    pymysql
    单表查询与多表查询
    多线程学习(第三天)线程间通信
    多线程学习(第二天)Java内存模型
    多线程学习(第一天)java语言的线程
    springboot集成es7(基于high level client)
    elasticSearch(六)--全文搜索
    elasticSearch(五)--排序
    elasticSearch(四)--结构化查询
  • 原文地址:https://www.cnblogs.com/cheungxiongwei/p/10710968.html
Copyright © 2011-2022 走看看