zoukankan      html  css  js  c++  java
  • 逻辑回归-二分类

    /// <summary>
    /// Binary logistic-regression demo: fits y = sigmoid(theta1*x1 + theta2*x2 + b) to four
    /// hard-coded two-feature samples using random mini-batches, then prints the fitted equation.
    /// </summary>
    class 二分类Demo
        {
            // A single generator shared by every draw. Building a new clock-seeded Random for each
            // draw (as the weight initialisation and batch sampling used to do) produces repeated
            // values whenever several instances are created within the same clock tick.
            private static readonly Random Rng = new Random();

            /// <summary>
            /// Trains the model for up to 30000 epochs, logs the batch cost every 100 epochs,
            /// prints the learned parameters and then waits for a line of console input.
            /// </summary>
            public static void Demo()
            {
                List<float[]> inputs_x = new List<float[]>
                {
                    new float[] { 0.9f, 0.6f },
                    new float[] { 2f, 2.5f },
                    new float[] { 2.6f, 2.3f },
                    new float[] { 2.7f, 1.9f },
                };

                List<float> inputs_y = new List<float> { 0, 1, 1, 1 };

                // weights[0] is the bias b; weights[1] and weights[2] are theta1 and theta2.
                float[] weights = new float[inputs_x[0].Length + 1];
                for (var i = 0; i < weights.Length; i++)
                {
                    weights[i] = (float)Rng.NextDouble();
                }

                int epoch = 30000;
                float epsilon = 0.00000000000000001f;
                float lr = 0.01f;

                float lastCost = 0;

                for (var epoch_i = 0; epoch_i <= epoch; epoch_i++)
                {
                    // Draw a random mini-batch (sampled with replacement).
                    var batch = GetRandomBatch(inputs_x, inputs_y, 2);

                    // Accumulate (target - prediction) * d(z)/d(weight) over the batch.
                    float[] gradient = new float[weights.Length];
                    foreach (var x_y in batch)
                    {
                        var x1 = x_y.Item1[0];
                        var x2 = x_y.Item1[1];

                        float diffWithTargetY = x_y.Item2 - Predict(weights, x1, x2);

                        gradient[0] += diffWithTargetY * dy_b(x1, x2);
                        gradient[1] += diffWithTargetY * dy_theta1(x1, x2);
                        gradient[2] += diffWithTargetY * dy_theta2(x1, x2);
                    }

                    for (var i = 0; i < weights.Length; i++)
                    {
                        weights[i] += lr * gradient[i];
                    }

                    // Half squared error of the updated weights, averaged over the same batch.
                    float totalErrorCost = 0f;
                    foreach (var x_y in batch)
                    {
                        float diffWithTargetY = x_y.Item2 - Predict(weights, x_y.Item1[0], x_y.Item1[1]);
                        totalErrorCost += diffWithTargetY * diffWithTargetY / 2;
                    }

                    float cost = totalErrorCost / batch.Count;

                    // NOTE(review): this compares the costs of two different random batches, and
                    // epsilon is far below float resolution for typical cost magnitudes, so the
                    // test behaves like an exact-equality check - confirm this is the intended
                    // stopping rule.
                    if (System.Math.Abs(cost - lastCost) <= epsilon)
                    {
                        Console.WriteLine("EPOCH {0}", epoch_i);
                        Console.WriteLine("LAST MSE {0}", lastCost);
                        Console.WriteLine("MSE {0}", cost);
                        break;
                    }

                    lastCost = cost;

                    if (epoch_i % 100 == 0 || epoch_i == epoch)
                    {
                        Console.WriteLine("MSE {0}", cost);
                    }
                }

                print(weights[1], weights[2], weights[0]);

                Console.ReadLine();
            }

            /// <summary>
            /// Picks <paramref name="maxCount"/> samples uniformly at random, with replacement.
            /// </summary>
            /// <param name="inputs_x">Feature vectors.</param>
            /// <param name="inputs_y">Target labels, index-aligned with <paramref name="inputs_x"/>.</param>
            /// <param name="maxCount">Number of samples to draw.</param>
            /// <returns>A list of (features, label) pairs.</returns>
            private static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
            {
                var lst = new List<Tuple<float[], float>>();

                for (int count = 0; count < maxCount; count++)
                {
                    int rndIndex = Rng.Next(inputs_x.Count);
                    lst.Add(Tuple.Create(inputs_x[rndIndex], inputs_y[rndIndex]));
                }

                return lst;
            }

            /// <summary>Writes the fitted model equation to the console.</summary>
            private static void print(float theta1, float theta2, float b)
            {
                Console.WriteLine(string.Format("y=sigmoid({0}*x1+{1}*x2+{2})", theta1, theta2, b));
            }

            /// <summary>Model output for one sample: sigmoid of the linear combination.</summary>
            /// <param name="weights">Bias at index 0, then theta1 and theta2.</param>
            private static float Predict(float[] weights, float x1, float x2)
            {
                return Sigmoid(fun(x1, x2, weights[1], weights[2], weights[0]));
            }

            /// <summary>Logistic function 1 / (1 + e^-x), evaluated in double and narrowed to float.</summary>
            private static float Sigmoid(double x)
            {
                return (float)(1.0 / (1.0 + System.Math.Exp(-x)));
            }

            /// <summary>Linear combination theta1*x1 + theta2*x2 + b.</summary>
            private static float fun(float x1, float x2, float theta1, float theta2, float b)
            {
                return theta1 * x1 + theta2 * x2 + b;
            }

            /// <summary>Partial derivative of the linear combination with respect to theta1.</summary>
            private static float dy_theta1(float x1, float x2)
            {
                return x1;
            }

            /// <summary>Partial derivative of the linear combination with respect to theta2.</summary>
            private static float dy_theta2(float x1, float x2)
            {
                return x2;
            }

            /// <summary>Partial derivative of the linear combination with respect to the bias b.</summary>
            private static float dy_b(float x1, float x2)
            {
                return 1;
            }
        }
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Feature columns of the four samples used by the C# demo (x1 = first feature, x2 = second).
    # The np.float alias was removed in NumPy 1.24; the builtin float selects the same float64 dtype.
    x1=np.array([0.9,2,2.6,2.7], dtype=float)
    x2=np.array([0.6,2.5,2.3,1.9], dtype=float)
    
    
    def sigmoid(v):
        """Logistic function 1 / (1 + e**-v); works on scalars and NumPy arrays."""
        denominator = 1 + np.exp(-v)
        return 1 / denominator
    
    #y=sigmoid(0.8277054*x1+0.893053*x2+-0.7201675)
    #y=sigmoid(1.968242*x1+4.206787*x2+-7.930489)
    def y_function(x1, x2):
        """Predicted probability from the second set of fitted weights listed above."""
        # Earlier fit, kept for reference: sigmoid(0.8277054*x1+0.893053*x2-0.7201675)
        logit = 1.968242*x1+4.206787*x2-7.930489
        return sigmoid(logit)
    
    # Model output for every sample.
    y = y_function(x1, x2)

    # One marker per sample: red when the output exceeds 0.5, blue otherwise.
    for index, y_value in enumerate(y):
        color = 'red' if y_value > 0.5 else 'blue'
        plt.scatter([x1[index]], [x2[index]], c=color, marker='o')

    plt.show()

     

    3D图形:

    import matplotlib.pyplot as plt
    import numpy as np
    from mpl_toolkits.mplot3d import Axes3D
    
    def sigmoid(v):
        """Map v (scalar or NumPy array) through the logistic curve 1 / (1 + e**-v)."""
        decay = np.exp(-v)
        return 1 / (1 + decay)
    
    # 100 sample points; x1 and x2 use the same grid, so the points lie along x1 == x2.
    x1=np.linspace(1, 5, 100)
    x2=np.linspace(1, 5, 100)
    # Linear score from the fitted weights, then squashed into (0, 1).
    y=1.968242*x1+4.206787*x2-7.930489
    y_s=sigmoid(y)


    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure on current Matplotlib releases,
    # which leaves the window empty; asking the figure for a 3D subplot works everywhere.
    ax = fig.add_subplot(projection='3d')
    ax.scatter(x1, x2, y_s, c='r', marker='.', s=50, label='')
    plt.show()

  • 相关阅读:
    Error -26631: HTTP Status-Code=400 (Bad Request) for
    mysql中的制表符替换
    mysql中json数据的拼接方式
    使用Nightwatch.js做基于浏览器的web应用自动测试
    Selenium + Nightwatch 自动化测试环境搭建
    Python web 框架:web.py
    转 Python Selenium设计模式-POM
    自动化测试
    日志打印logging模块(控制台和文件同时输出)
    读取配置文件(configparser,.ini文件)
  • 原文地址:https://www.cnblogs.com/aarond/p/7966592.html
Copyright © 2011-2022 走看看