zoukankan      html  css  js  c++  java
  • 如何用C++编写不到200行的神经网络

    目前已经有许多现成的深度学习框架,为什么我们还要用C++来编写一个神经网络?一个理由是我们需要了解学习框架内部的运行原理,当分析问题的时候能够很快的定位原因;另一个理由是,我们需要为专有设备编写一个推理引擎,它可能运行在手机端,或者移动设备上。这篇文章实现了一个最简单的神经网络框架,适合大家入门学习。

    参考代码链接:https://github.com/webbery/MiniEngine/tree/demo

    一个最简单的神经网络包含输入节点、一个线性层、一个激活函数和一个损失函数。为了减少编码,我们采用Eigen代替numpy,来实现相关的矩阵操作。

    首先定义一个Node节点类作为基类,它包含一个输入节点的队列和一个输出节点的队列;以及当前节点的值,与它连接的反向传播的梯度值,前向传播接口和反向传播接口。另外为了调试方便,我们给Node增加了一个name成员变量,来标识数据对应到哪个节点。

     1 class Node {
     2 public:
     3   virtual void forward() = 0;
     4   virtual void backward() = 0;
     5 protected:
     6   Eigen::MatrixXf _value;
     7   std::vector<Node*> _inputs;
     8   std::vector<Node*> _outputs;
     9   std::map<Node*, Eigen::MatrixXf> _gradients;
    10   std::string _name;
    11 };

      输入节点Input继承Node类,代表输入变量,这些变量将被数据赋值。对应于tensorflowVariable类。这里我们在构造函数里初始化了它的大小rowcol

    1 class Input : public Node {
    2 public:
    3     Input(const char* name,size_t rows=0,size_t cols=0);
    4 };

      Linear节点,代表全连接层,前向传播接口的实现方式为WX+bias,其中biasWX计算出来的矩阵形式是不相同的,需要对bias做一个广播操作;反向传播需要计算输出节点对应的WXbias梯度值

     1 class Linear : public Node {
     2 public:
     3     Linear(Node* nodes, Node* weights, Node* bias);
     4 
     5     virtual void forward(){_value = (_nodes->getValue() * _weights->getValue()).rowwise() + Eigen::VectorXf(_bias->getValue()).transpose();}
     6 
     7     virtual void backward(){
     8     for (auto node : _outputs){
     9       auto grad = node->getGradient(this);
    10       _gradients[_weights] = _nodes->getValue().transpose() * grad;
    11       _gradients[_bias] = grad.colwise().sum().transpose();
    12       _gradients[_nodes] = grad * _weights->getValue().transpose();
    13     }
    14   }
    15 
    16 private:
    17     Node* _nodes = nullptr;
    18     Node* _weights = nullptr;
    19     Node* _bias = nullptr;
    20 };

      Sigmoid节点,代表激活函数,前向传播计算sigmoid函数结果,反向传播计算sigmoid导函数

     1 class Sigmoid : public Node {
     2 public:
     3   Sigmoid(Node* node);
     4   virtual void forward(){_value = _impl(_node->getValue());}
     5   virtual void backward(){
     6     auto y = _value;
     7     auto y2 = y.cwiseProduct(y);
     8     _partial = y-y2;
     9 
    10     for (auto node : _outputs) {
    11       auto grad = node->getGradient(this);
    12       _gradients[_node] = grad.cwiseProduct(_partial);
    13     }
    14   }
    15 private:
    16   Eigen::MatrixXf _impl(const Eigen::MatrixXf& x){return (-x.array().exp() + 1).inverse();}
    17 private:
    18   Node* _node = nullptr;
    19   //sigmoid的偏导
    20   Eigen::MatrixXf _partial;
    21 };

     MSE节点,代表Loss function,前向传播计算估计值与真实值的方差,反向传播需要计算方差的一阶导数结果

     1 class MSE : public Node {
     2 public:
     3   MSE(Node* y, Node* y_hat);
     4   virtual void forward(){
     5     _diff = _y->getValue() - _y_hat->getValue();
     6     auto diff2= _diff.cwiseProduct(_diff);
     7     auto v = Eigen::MatrixXf(1, 1);
     8     v << (diff2).mean();
     9     _value = v;
    10   }
    11   virtual void backward(){
    12     auto r = _y_hat->getValue().rows();
    13     _gradients[_y] = _diff * (2.f / r);
    14     _gradients[_y_hat] = _diff * (-2.f / r);
    15   }
    16 private:
    17   Node* _y;
    18   Node* _y_hat;
    19   Eigen::MatrixXf _diff;
    20 };

     这几个节点是我们将要构建的框架的基本元素。然后我们还需要实现一个图的拓扑排序,对排序后的节点进行前向迭代,计算预测结果;然后再反向迭代,计算每个连接的梯度。

     1 std::vector<Node*> topological_sort(Node* input_nodes){
     2   //根据传入的数据初始化图结构
     3   Node* pNode = nullptr;
     4   //pair第一个为输入,第二个为输出
     5   std::map < Node*, std::pair<std::set<Node*>, std::set<Node*> > > g;
     6   //待遍历的周围节点
     7   std::list<Node*> vNodes;
     8   vNodes.emplace_back(input_nodes);
     9   //广度遍历,先遍历输出节点,再遍历输入节点
    10   //已经遍历过的节点
    11   std::set<Node*> sVisited;
    12   while (vNodes.size() && (pNode = vNodes.front())) {
    13     if (sVisited.find(pNode) != sVisited.end()) vNodes.pop_front();
    14     const auto& outputs = pNode->getOutputs();
    15     for (auto item: outputs){
    16       g[pNode].second.insert(item);    //添加item为pnode的输出节点
    17       g[item].first.insert(pNode);    //添加pnode为item的输入节点
    18       if(sVisited.find(item)==sVisited.end()) vNodes.emplace_back(item);    //把没有访问过的节点添加到待访问队列中
    19     }
    20     const auto& inputs = pNode->getInputs();
    21     for (auto item: inputs){
    22       g[pNode].first.insert(item);    //添加item为pnode的输入节点
    23       g[item].second.insert(pNode);    //添加pnode为item的输出节点
    24       if (sVisited.find(item) == sVisited.end()) vNodes.emplace_back(item);
    25     }
    26     sVisited.emplace(pNode);
    27     vNodes.pop_front();
    28   }
    29 
    30   //根据图结构进行拓扑排序
    31   std::vector<Node*> vSorted;
    32   while (g.size()) {
    33     for (auto itr=g.begin();itr!=g.end();++itr)
    34     {
    35       //没有输入节点
    36       auto& f = g[itr->first];
    37       if (f.first.size() == 0) {
    38         vSorted.push_back(itr->first);
    39         //找到图中这个节点的输出节点,然后将输出节点对应的这个父节点移除
    40         auto outputs = f.second;//f['out']
    41         for (auto& output: outputs) g[output].first.erase(itr->first);
    42         //然后将这个节点从图中移除
    43         g.erase(itr->first);
    44         break;
    45       }
    46     }
    47   }
    48   return vSorted;
    49 }

     测试程序中,我们定义了每个节点,并构造了节点之间的连接关系;之后把输入节点传给了topological_sort。该函数从输入节点开始,进行广度优先遍历,构建一个图结构;然后根据拓扑排序算法,将这个图结构排序成一个队列返回;这个队列在tensorflow里被称为“图”。

    然后,测试程序运行train_one_batch前向遍历一次得到预测值,然后再反向遍历一次,得到每个节点连接的梯度变化;

    1 void train_one_batch(std::vector<Node*>& graph){
    2   for (auto node:graph){
    3     node->forward();
    4   }
    5   for (int idx = graph.size() - 1; idx >= 0;--idx) {
    6     graph[idx]->backward();
    7   }
    8 }

     接着根据计算出来的梯度值,更新权重节点Wb,完成一次训练

    1 void sgd_update(std::vector<Node*> update_nodes, float learning_rate){
    2   for (auto node: update_nodes){
    3     Eigen::MatrixXf delta = -1 * learning_rate * node->getGradient(node);
    4     node->setValue(node->getValue() + delta);
    5   }
    6 }
  • 相关阅读:
    敏捷个人2011.7月份第一次线下活动报道:迷茫、游戏和故事中的敏捷个人.
    敏捷开发:60分钟掌握敏捷估计和规划
    敏捷之旅北京2011.11月份活动报道:让敏捷落地
    敏捷个人2011.6月份线下活动:拖延、知道力分享
    答TOGAF企业架构的一些问题
    活动推荐:Agile Tour 2011北京站“让敏捷落地”
    敏捷个人2011.5月份线下活动主题一:培养好习惯
    第二届清华大学项目管理精英训练营【敏捷个人】分享
    产品管理:产品的三种驱动类型技术、销售和市场驱动
    敏捷个人线上线下活动PPT及照片做成的视频共享
  • 原文地址:https://www.cnblogs.com/webbery/p/11590451.html
Copyright © 2011-2022 走看看