zoukankan      html  css  js  c++  java
  • demo1

    写了个demo,只有两个文件。 

    定义了一个叫 Square 的类,相当于pytorch中的 Function,包含了forward和backward方法。实现 y = a*x^2。

    Square.m

     1 classdef Square < handle
     2     properties
     3         input;
     4         a;
     5         
     6         grad_input;
     7         d_a;
     8     end
     9     
    10     methods
    11         function self = Square(a)
    12             self.a = a;
    13             self.grad_input = 0;
    14             self.d_a = 0;
    15         end
    16         
    17         function out = forward(self, input)
    18             out = self.a*input.^2; % y = a*x^2  这个相当于model
    19             self.input = input;
    20         end
    21         
    22         function grad_input = backward(self, grad_output)
    23             % dy = 2*a*dx 以x为变量,根据dy 求 dx   
    24             % dy = 2*x*da 以待训练权重a为变量,根据dy 求 d_a,不求出d_a来,
    25             % a_iter是没法更新的。
    26             % 
    27             % 之前没有自己实现前后向的计算,真的没意识到要求两次微分。
    28             grad_input = grad_output/(2*self.a);
    29             self.d_a = grad_output/(2*self.input);
    30         end
    31     
    32     end
    33 end
    View Code

    然后是main.m

     1 clear all;clc;
     2 % y = a*x^2 只有一个待求参数a,所以只需要一个样本点(x, label)就可以求出a
     3 
     4 x = 3;
     5 a = 4;
     6 label = a*x^2;
     7 
     8 a_iter = rand();     % 随机初始化需要训练的参数 a
     9 sq = Square(a_iter); % new 一个 model 出来
    10 
    11 loss_tb = zeros(1, 100000);  % loss table 用来记录误差
    12 
    13 lr = 0.0001; % learning rate
    14 
    15 for i = 1:100000          % epoch = 100000
    16     out = sq.forward(x);  % 前向计算
    17     loss = (label-out)^2; % criterion, 简单的用差的平方做loss
    18 %     loss = (label-out);
    19     loss_tb(i) = loss;
    20     
    21     dx1 = sq.backward(loss);  
    22     % 反向计算,因为只有一个环节,所以这里dx1后面不再需要用到。
    23     % 只是为了跑一下backward方法,计算下sq.d_a
    24     
    25     a_iter = a_iter + lr*sq.d_a;  % 相当于 optimizer
    26     sq.a = a_iter;                % 更新模型参数
    27 end
    28 
    29 a_iter  
    30 % 总是等于3.9926,与4有点差距,可能是数值计算方面的原因。
    31 % 上面如果 loss = label - out,反而能精确的等于4,但是误差下降的更慢
    32 
    33 plot(log(1+loss_tb))
    View Code

    这个demo基本上体现了pytorch的设计思路。

    写完这个demo自然而然就想到了batchSize的问题,意识到了为什么pytorch要在model里面

    搞出一个Parameter类出来。

    =========2018年1月29日17:23:34=========

    这两天一直在想怎么实现DAG,试了下,还是比较麻烦的。

    matconvnet本身用的结构体,pytorch在ATen里用了三个list,darknet用的双向链表。

    索性不实现DAG了,手动把模块装起来搞前后向。

    自己不会cuda编程,奔着GPU coder以及内置的一些这两年的神经网络模型,

    装了个matlab R2017b。看看能不能通过学习一些基本的cuda编程知识,自己实现个gpu上的

    卷积运算。

    =========2018年2月1日23:26:21=========

    matconvnet的 DAGNN里头什么都有,已经有模块化的forward,backward了。

    simplenn可读性太差。

    =========2018年2月4日00:01:47=========

    决定花点时间研究下matconvnet了,因为仔细一看发现了这个:

    https://github.com/vlfeat/matconvnet-contrib

    =========2018年2月4日00:47:35=========

    ceres solver中有个求自动微分的功能:

    https://github.com/ceres-solver/ceres-solver/blob/c2da96082b7ea4d6cdcb1ca83a3e84264156ea48/include/ceres/jet.h

    =========2018年2月4日14:46:52=========

    https://github.com/necroen/toy_lenet

    =========2018年2月10日13:27:37=========

  • 相关阅读:
    #define #undef
    ps
    Find–atime –ctime –mtime的用法与区别总结
    redis
    linux mutex
    private继承
    boost::noncopyable介绍
    Makefile 中:= ?= += =的区别
    linux Tar 命令参数详解
    Ubuntu14.04安装CMake3.0.2
  • 原文地址:https://www.cnblogs.com/shepherd2015/p/8378394.html
Copyright © 2011-2022 走看看