zoukankan      html  css  js  c++  java
  • SOM自组织映射网络 教程

    概述

    SOM是芬兰教授Teuvo Kohonen提出的一种神经网络算法,它提供一种将高维数据在低维空间进行表示的方法(通常是一维或二维)。缩减向量维度的过程,叫做向量量化(vector quantisation)。此外,SOM网络能保留原有数据的拓扑关系。

    一个用来直观感受SOM网络规则的例子,是将3维颜色映射到二维空间,如图所示。

    图1

    左图的颜色是按(r,g,b)组合形式表示的,SOM网络经过学习后能他这些颜色在二维空间进行表示。如右图所示:为了让颜色聚类,相似的属性通常被发现是相邻的。这种特性被很好的使用,后面你还会看到的。

    SOM最有趣的一个方面是,它是无监督学习的。你可能已经知道有监督训练,比如BP神经网络,它的训练数据是由(input, target)向量元组组成的。用这种有监督的方法,当给定输入向量到网络中(典型的是一个多层前馈网络),输出会被用来和目标向量比较,如果它们不相同,则微调网络中的权重以减小输出误差。这一过程重复多次,并且是对于所有(input,target)向量元组进行的,直到网络给出想要的输出。然而训练一个SOM网络,则不需要目标输出向量target。SOM网络学着对训练数据进行分类而不需要任何外部监督。是不是很神奇?

    在说出事实真相之前,你最好忘掉你对神经网络的任何知识!如果你尝试着用神经元、激活函数、前馈/反馈连接这些术语来思考SOM,你很可能很快就懵了。所以,把脑中那些知识先丢到一边去吧!

    样例代码下载(C++版):http://www.ai-junkie.com/files/SOMDemo.zip

    ZIP包中包括了可执行程序,没有编译器也可以玩的!

    另外有人提供一个java版本的:http://www.ai-junkie.com/files/SOMDemo_java.zip

    网络架构

    本教程使用2维的SOM网络。网络从2维网格节点创建,每个节点和输入层是全连接的。(译注:下述如未特别声明,“节点”均表示“竞争层节点”)下图是一个有4*4个节点、和输入层全连接的网络架构,表示了一个二维的SOM网络:

    图2

    每个节点都有一个拓扑位置,即节点中的(x,y)坐标,同时也包含一个和输入向量同维度的权重向量。也就是说,如果训练数据是由向量n维向量V组成的:

    V1,V2,V3,…,V

    那么每个节点都会包含一个相应的n维权重向量W:

    W1,W2,W3,…Wn

    上图中连接节点的连线,仅仅是用来表示邻接,并不意味着通常谈论的神经网络中的一个连接。网格中的节点之间没有侧连接(不懂!前面看到的教程都说是有的欸)。

    图1所示的SOM网络有一个尺寸为40*40的默认网格,网格中的每个节点有三个权重,对应着输入向量中的三个维度:red,green,blue(这里也不懂。为什么是3个权重连接,不应该是8个嘛?)。每个节点以长方形cell的形式表示。图3显示了将cell边界渲染为黑色的cell们,这样能看得更清楚些:

    图3

    学习算法概览

    SOM网络不需要目标输出,这是它和其它种类网络区别的地方。相应地,对于节点权重和输入向量匹配的地方,被选择性地优化为和输入数据所属类别更像。从初始时随机权重分布,经过多次迭代,SOM网络最终形成多个稳定的区域。每个区域都是一个有效的特征分类器,因此你可以把图形输出认为是输入空间的一种特征映射(feature map)。如果你再看一眼图1,会发现相似颜色的区块代表了独立的区域。任何新的、以前没见过的输入向量,一旦它输入到了网络中,将会刺激到具有相似权重向量的节点。

    训练在很多步骤中都出现并迭代多次:

    1. 每个节点的权重被初始化
    2. 输入向量是从训练数据中随机选出的,然后被放到网格中
    3. 每个节点都被检查:用节点的权重和输入向量进行比较,找到最相似的那一个。获胜节点通常叫做最佳匹配单元(Best Matching Unit,BMU)。
    4. BMU的邻域的半径被计算。这个值最开始很大,通常设定为和网格半径相同(fuck,网格半径是什么鬼,到底是多少?),随着迭代次数增加会减小。此半径内任何发现的节点被认为是术语BMU邻域内的。
    5. 每个邻域内的节点的权重被调整,从而使得他们和输入向量更像。距离BMU越近,权重改变就越大。
    6. 回到步骤2,重复执行N次。

    学习算法的细节

    现在来详细看看每一步是怎么做的。我会用适当的代码片段补充我的描述。也希望这些代码对于懂编程的读者能增加理解。

    1. 1.    初始化权重

    训练之前,每个节点的权重一定要被初始化。典型地,这些将被设定为小的归一化的随机值。下面的代码片段中SOM的权值被初始化,因而有0<w<1。节点在CNode类中定义。相关的类为:

    class CNode

    {

    private:

      //this node's weights

      vector<double>      m_dWeights;

      //its position within the lattice

      double              m_dX,

                          m_dY;

      //the edges of this node's cell. Each node, when draw to the client

      //area, is represented as a rectangular cell. The colour of the cell

      //is set to the RGB value its weights represent.

      int                 m_iLeft;

      int                 m_iTop;

      int                 m_iRight;

      int                 m_iBottom;

    public:

      CNode(int lft, int rgt, int top, int bot, int NumWeights):m_iLeft(lft),

                                                                m_iRight(rgt),

                                                                m_iBottom(bot),

                                                                m_iTop(top)

      {

        //initialize the weights to small random variables

        for (int w=0; w<NumWeights; ++w)

        {

          m_dWeights.push_back(RandFloat());

        }

        //calculate the node's center

        m_dX = m_iLeft + (double)(m_iRight - m_iLeft)/2;

        m_dY = m_iTop  + (double)(m_iBottom - m_iTop)/2;

      }

      ...

    };

    可以看出,当创建一个CNode对象时候,权值在构造函数中被自动地初始化。成员变量m_iLeft,m_iRight,m_iTop和m_iBottom仅仅被用来渲染网络到输出区域——每个节点以矩形cell的形式显示,这个矩形cell是由这些值来描述的。如图3所示。

    1. 2.    计算最佳匹配单元BMU

    为了算出BMU,一个方法是遍历所有节点并计算节点的权重向量和当前输入向量的欧氏距离。如果一个节点的权值向量和输入向量最接近,那么这个节点就被标记为BMU。

    欧氏距离计算式为:

    公式1

    其中V表示当前输入向量,W表示节点的权重向量。在代码中,这个等式被翻译为:

    double CNode::GetDistance(const vector<double> &InputVector)

    {

      double distance = 0;

      for (int i=0; i<m_dWeights.size(); ++i)

      {

        distance += (InputVector[i] - m_dWeights[i]) * (InputVector[i] - m_dWeights[i]);

      }

      return sqrt(distance);

    }

    例如,计算红色向量(1,0,0)到一个随机权重向量(0.1,0.4,0.5)的距离:

    distance = sqrt( (1 - 0.1)+ (0 - 0.4)2+ (0 - 0.5)2 )

                   = sqrt( (0.9)+ (-0.4)2+ (-0.5)2 )

                   = sqrt( 0.81 + 0.16+ 0.25 )

                    = sqrt(1.22)

    distance  = 1.106

    1. 3.    计算BMU的局部邻域

    从这里开始,事情变得有趣了!每次迭代,只要BMU被算出了,那么下一步就是计算出哪些其他节点是在BMU的邻域内。所有这些节点在下一步都会更新其权值。那么,怎样做到呢?其实不难…首先算出邻域应当具备的半径,然后就是勾股定理的简单使用,来算出每个节点是否属于这个邻域。图4是简单示例:

    图4

    可以看到,邻域是以BMU为中心(黄色圆球)的圆形区域。绿色箭头表示半径。

    Kohonen学习算法的一个独特特征是,随着迭代次数的增加,邻域会变小,这通过缩减半径做到,比如:

    公式2

    其中希腊字母s0表示t0时间的网格宽度,而希腊字母l表示时间常量。t表示当前时间步骤(迭代次数)。我的代码中把s命名为m_dMapRadius,它在训练结束的时候等于s0。我计算s0的方式为:

    m_dMapRadius = max(constWindowWidth, constWindowHeight)/2;

    l的值依赖于s的取值和算法迭代次数的取值。我把l命名为m_dTimeConstant,计算代码为:

    m_dTimeConstant = m_iNumIterations/log(m_dMapRadius);

    其中m_iNumIterations表示迭代次数。在constants.h中用户可进行修改。

    我们可以使用m_dTimeContant和m_dMapRadius在每次迭代中使用公式2计算邻域半径:

    m_dNeighbourhoodRadius = m_dMapRadius * exp(-(double)m_iIterationCount/m_dTimeConstant);

    图5显示了图4中邻域是如何随着迭代次数的增加而缩减的(假设BMU不变):

    图5

    随着时间推移,邻域会缩减为只剩BMU本身。

    现在我们知道了半径和邻域的变化情况,邻域覆盖到的节点需要调整其权值,调整规则如下所示。

    1. 4.    调整权值

    BMU邻域内的每个节点,包括BMU在内,都要根据下式调整权值向量:

    公式3

    其中t表示迭代次数,L是一个小变量,称为学习率,会随着时间推移而减小。这个调整公式的意思是,新的权值,是在老的权值基础上,加上学习率乘以老的权值向量与输入向量的差值。

    学习率的衰减,通过如下公式计算:

    公式4

    其实很容易发现公式4和公式2是相同的形式。对应的代码为:

    m_dLearningRate = constStartLearningRate * exp(-(double)m_iIterationCount/m_iNumIterations);

    训练一开始时的学习率,constStartLearningRate,在constant.h文件中进行设定了为0.1。然后随着时间递减最终接近0。

    注意!公式3其实是不完全正确的!因为没有考虑到节点到BMU的距离。距离的影响可以认为是符合高斯分布的:

    对此,我们改进公式3,得到:

    公式5

    其中q表示节点到BMU距离的影响力度:

    公式6

    SOM的应用

    SOM常常被用来做可视化的助手,能帮助人类更容易地看到大数据之间的关系。我举个栗子:

    世界贫困图

    SOM被用来为统计数据做分类,包括关于生活质量的种种数据,比如健康状态、营养状态、教育服务程度等。具备相似的生活质量的国家被聚集在一起。左上方的国家是生活质量高的国家,而最穷的国家在图中右下方。六边形网格是一个统一距离矩阵,通常叫做u-matrix。每个六边形代表SOM网络中的一个节点。

    接下来可以把颜色信息绘制到地图上,比如:

    这使得我们对于贫困数据的理解更容易了。

    其他应用

    SOM在许多其他方面被应用。比如:

    书目分类

    图像浏览系统

    医学诊断

    解释地震活动

    语音识别(Kohonen最初就是要做这个的)

    数据压缩

    分离声源

    环境模型

    甚至是吸血鬼分类!

  • 相关阅读:
    vue项目webpack配置terser-webpack-plugin 去掉项目中多余的debugger
    difference between count(1) and count(*)
    为什么PostgreSQL WAL归档很慢
    mysql_reset_connection()
    Oracle使用audit跟踪登录失败的连接信息
    .NET Standard 版本
    ASP.NET Web API版本
    我是如何用go-zero 实现一个中台系统的
    JAVA中文件写入的6种方法
    MySql 常用语句
  • 原文地址:https://www.cnblogs.com/zjutzz/p/5059131.html
Copyright © 2011-2022 走看看