zoukankan      html  css  js  c++  java
  • 感知机算法的两种表示

     

    感知机算法的原始形式

    输入:训练数据集T={(x1,y1),(x2,y2),...,(xn,yn)},其中xi属于Rn(n维空间向量),yi={-1,+1},i=1,2,...,N

    学习率t(0<t<=1);

    输出:w,b;感知机模型f(x)=sign(w*x+b)

    (1)选取初值w0,b0一般为 0

    (2)在训练集中选取数据(xi,yi)

    (3)如果yi(w*xi+b)<=0

                              w<---w+tyixi

                               b<----b+tyi

    (4)转至(2),直至训练集中没有误分类点。

    算法代码如下

     1 import java.util.Random;
     2 import java.util.Scanner;
     3 import java.util.regex.Matcher;
     4 import java.util.regex.Pattern;
     5 
     6 //初始的感知机学习算法
     7 public class ganzhijiOriginal {
     8     //    数据集中数据的个数
     9     public static int NCount;
    10     //    数据中每个数值的维数 不包含y
    11     public static int N;
    12     //    学习率  (0,1]
    13     public static float t;
    14 
    15     //    保存数据
    16     public static int datas[][];
    17     //    权重向量
    18     public static float w[];
    19     //   偏移数据
    20     public static float b;
    21 
    22     public static void main(String args[]) {
    23         Scanner sca = new Scanner(System.in);
    24 //       默认输入的格式为 N-t-b
    25         System.out.println("输入格式为N-t-b-NCount的数据:");
    26         String lines = sca.nextLine();
    27         Pattern pattern = Pattern.compile("(.*)-(.*)-(.*)-(.*)");
    28         Matcher matcher = pattern.matcher(lines);
    29         matcher.find();
    30         N = Integer.parseInt(matcher.group(1));
    31         t = Float.parseFloat(matcher.group(2));
    32         b = Float.parseFloat(matcher.group(3));
    33 //        b = Integer.parseInt(matcher.group(3));
    34         NCount = Integer.parseInt(matcher.group(4));
    35 //        System.out.println("请输入权重向量的初始值格式为格式为x,x,x,x" + N);
    36 //初始化数据集
    37         w = new float[N];
    38 //        lines = sca.nextLine();
    39 //        String[] str = lines.split(",");
    40 //        for (int i = 0; i < N; i++)
    41 //            w[i] = Float.parseFloat(str[i]);
    42         datas = new int[NCount][N + 1];
    43         System.out.println("请输入所有权重向量的初始值");
    44         for (int i = 0; i < NCount; i++) {
    45             String line = sca.nextLine();
    46             String strs[] = line.split(" ");
    47             for (int j = 0; j <= N; j++)
    48                 datas[i][j] = Integer.parseInt(strs[j]);
    49         }
    50         CountTheTimes = 0;
    51         Random ra = new Random();
    52         int chooseNumber = ra.nextInt(NCount);
    53         CalcuteAndUpdatValue(0);
    54         String strd = "";
    55         for (int i = 0; i < N; i++)
    56             strd += w[i] + "*x" + i + "+ ";
    57         System.out.println("F(x)=sign(" + strd + b + ")");
    58     }
    59 
    60     public static int CountTheTimes;
    61 
    62     private static void CalcuteAndUpdatValue(int chooseNumber) {
    63         float f = isPOrN(chooseNumber);
    64         boolean flages = f * datas[chooseNumber][N] > 0 ? true : false;
    65         if (!flages) {
    66             for (int j = 0; j < N; j++) //更新权重w的值
    67                 w[j] = w[j] + t * datas[chooseNumber][N] * datas[chooseNumber][j];
    68             b = b + t * datas[chooseNumber][N];
    69 
    70             CountTheTimes = 0;//初始化
    71         } else {
    72             CountTheTimes++;
    73             chooseNumber = (chooseNumber + 1) % NCount;
    74         }
    75         if (CountTheTimes == NCount) return;
    76         CalcuteAndUpdatValue(chooseNumber);
    77     }
    78 
    79     private static float isPOrN(int chooseNumber) {
    80         float sum = 0;
    81         for (int i = 0; i < N; i++) sum += datas[chooseNumber][i] * w[i];
    82         return sum + b;
    83     }
    84 
    85 }
    View Code

    感知机算法的对偶形式

    输入:线性可分的数据集T={(x1,y1),(x2,y2),...,(xn,yn)}其中xi属于Rn(n维向量),yi属于{-1,+1},i,2,。。。,N;学习率为t (0<t<=1)

    输出a,b 感知机模型f(x)=sign(j从1 到 ajyjxj*x累加  +b) 其中a=(a1,a2,a3...,an)T

    (1)a<---0,b<----0

    (2)在训练数据集中选择(xi,yi)

    (3)如果 yi(j从1 到 ajyjxj*xi累加  +b)<=0

                      ai<----ai+t

                     b<---b+tyi

    (4)转至(2)直到没有误分类点

    算法如下

    import java.util.Random;
    import java.util.Scanner;
    
    //感知机算法的对偶形式
    public class ganzhijiOudui {
        //存储数据的Gram矩阵
        public static int gramMatrix[][];
        //    存储初始点集合包括y
        public static int datas[][];
        //感知机中的学习率(0,1]
        public static float t;
        //    存储某个点更新的次数nt
        public static float a[];
        //    点的维度 不包含y
        public static int N;
        //    总点的数
        public static int NCount;
        //偏移量
        public static float b;
    
        public static int w[];//权重
    
        public static void main(String args[]) {
            Scanner sca = new Scanner(System.in);
            System.out.println("输入的格式为:N-t-b-NCount");
            String line = sca.nextLine();
            String dt[] = line.split("-");
            N = Integer.parseInt(dt[0]);
            t = Float.parseFloat(dt[1]);
            b = Float.parseFloat(dt[2]);
            NCount = Integer.parseInt(dt[3]);
            InitDatas(sca);
    
    
        }
    
        private static void InitDatas(Scanner sca) {
            datas = new int[NCount][N + 1];
            gramMatrix = new int[NCount][NCount];
            w = new int[N];
    //        默认值为零
            a = new float[NCount];
            System.out.println("输入数据点集合格式为x x x ... Y");
            for (int i = 0; i < NCount; i++) {
                String line = sca.nextLine();
                String strs[] = line.split(" ");
                for (int j = 0; j <= N; j++) datas[i][j] = Integer.parseInt(strs[j]);
            }
            System.out.println("初始化Gram矩阵");
            for (int i = 0; i < NCount; i++) {
                for (int j = 0; j < NCount; j++) {
                    gramMatrix[i][j] = MultiplyTheDatas(i, j);
                }
            }
    //        记录总的循环次数
            TotalTimes = 0;
            Random random = new Random();
            int chooseNumber = random.nextInt(NCount);
            CalculateAndUpdatDatas(0);
            //更新 w
            for (int i = 0; i < N; i++) {
                int sum = 0;
                for (int j = 0; j < NCount; j++) sum += a[j] * datas[j][i]*datas[j][N];
                w[i] = sum;
            }
            String strd = "";
            for (int i = 0; i < N; i++)
                strd += w[i] + "*x" + i + "+ ";
            System.out.println("F(x)=sign(" + strd + b + ")");
        }
    
        public static int TotalTimes = 0;
    
        private static void CalculateAndUpdatDatas(int chooseNumber) {
    
            float sum = CaluteTheFx(chooseNumber);
            boolean flages = sum * datas[chooseNumber][N] > 0 ? true : false;
            if (!flages) {
                TotalTimes = 0;
                a[chooseNumber] = a[chooseNumber] + t;
                b = b + t * datas[chooseNumber][N];
            } else {
                TotalTimes++;
                chooseNumber = (chooseNumber + 1) % NCount;
            }
            if (TotalTimes == NCount) return;
            CalculateAndUpdatDatas(chooseNumber);
        }
    
        private static float CaluteTheFx(int chooseNumber) {
    
            float sum = 0;
            for (int i = 0; i <=N; i++) sum += a[i] * datas[i][N] * gramMatrix[i][chooseNumber];
            return sum + b;
        }
    
    
        //    计算两个向量的乘积
        private static int MultiplyTheDatas(int i, int j) {
            int sum = 0;
            for (int k = 0; k <N; k++) sum += datas[i][k] * datas[j][k];
            return sum;
        }
    
    
    }
    View Code
  • 相关阅读:
    早晨突然想到的几句话
    VBA-工程-找不到工程或库-解决方案
    Mysql 服务无法启动 服务没有报告任何错误
    一道有趣的面试题
    异步和多线程
    异或运算
    线性代数解惑
    全文搜索引擎 Elasticsearch (一)
    HandlerExceptionResolver统一异常处理 返回JSON 和 ModelAndView
    MySQL 20个经典面试题
  • 原文地址:https://www.cnblogs.com/09120912zhang/p/7682033.html
Copyright © 2011-2022 走看看