zoukankan      html  css  js  c++  java
  • SGD

    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Text;
    using System.Threading.Tasks;
    
    namespace ConsoleApp4
    {
        /// <summary>
        /// Fits the linear model y = theta1 * x1 + theta2 * x2 + b to a small hard-coded data set
        /// using mini-batch stochastic gradient descent on the per-sample loss 0.5 * (y - f(x))^2.
        /// </summary>
        class Program
        {
            // One shared generator for the whole program. Creating a new time-seeded Random on
            // every call inside a tight loop can reuse the same seed, which makes consecutive
            // "random" initial weights / mini-batches identical.
            private static readonly Random s_random = new Random();

            static void Main(string[] args)
            {
                // Training data: each row of inputsX is (x1, x2); inputsY holds the matching targets.
                List<float[]> inputsX = new List<float[]>
                {
                    new float[] { 0.9f, 0.6f },
                    new float[] { 2f, 2.5f },
                    new float[] { 2.6f, 2.3f },
                    new float[] { 2.7f, 1.9f },
                };
                List<float> inputsY = new List<float> { 2.5f, 2.5f, 3.5f, 4.2f };

                // Parameter layout used throughout: weights[0] = b, weights[1] = theta1, weights[2] = theta2.
                float[] weights = new float[3];
                for (var i = 0; i < weights.Length; i++)
                {
                    weights[i] = (float)s_random.NextDouble();
                }

                const int maxEpoch = 30000;
                const float epsilon = 0.00001f;
                const float learningRate = 0.01f;
                const int batchSize = 2;

                // +infinity guarantees the convergence test cannot fire on the very first epoch.
                float lastCost = float.PositiveInfinity;

                // Inclusive bound: performs maxEpoch + 1 updates unless convergence is detected earlier.
                for (var epochIndex = 0; epochIndex <= maxEpoch; epochIndex++)
                {
                    // Randomly sample a mini-batch (with replacement) and take one SGD step on it.
                    var batch = GetRandomBatch(inputsX, inputsY, batchSize);
                    ApplyGradientStep(batch, weights, learningRate);

                    // Evaluate the loss on the full training set rather than on the random batch:
                    // comparing the costs of two different random batches measures batch-to-batch
                    // variation, not whether training has converged.
                    // Note: the value printed as "MSE" is half the mean squared error.
                    float cost = ComputeCost(inputsX, inputsY, weights);

                    if (Math.Abs(cost - lastCost) <= epsilon)
                    {
                        Console.WriteLine("EPOCH {0}", epochIndex);
                        Console.WriteLine("LAST MSE {0}", lastCost);
                        Console.WriteLine("MSE {0}", cost);
                        break;
                    }

                    lastCost = cost;

                    if (epochIndex % 100 == 0 || epochIndex == maxEpoch)
                    {
                        Console.WriteLine("MSE {0}", cost);
                    }
                }

                print(weights[1], weights[2], weights[0]);

                Console.ReadLine();
            }

            /// <summary>
            /// Draws <paramref name="maxCount"/> samples uniformly at random, with replacement,
            /// from the index-aligned inputs and targets.
            /// </summary>
            /// <param name="inputs_x">Feature vectors.</param>
            /// <param name="inputs_y">Targets, index-aligned with <paramref name="inputs_x"/>.</param>
            /// <param name="maxCount">Number of samples to draw (the batch size); non-positive yields an empty batch.</param>
            /// <returns>A list of (x, y) pairs; the same sample may appear more than once.</returns>
            /// <exception cref="ArgumentNullException">Either input list is null.</exception>
            /// <exception cref="ArgumentException">
            /// The lists differ in length, or they are empty while <paramref name="maxCount"/> is positive.
            /// </exception>
            internal static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
            {
                if (inputs_x == null)
                {
                    throw new ArgumentNullException(nameof(inputs_x));
                }
                if (inputs_y == null)
                {
                    throw new ArgumentNullException(nameof(inputs_y));
                }
                if (inputs_x.Count != inputs_y.Count)
                {
                    throw new ArgumentException("inputs_x and inputs_y must have the same number of elements.", nameof(inputs_y));
                }
                if (maxCount > 0 && inputs_x.Count == 0)
                {
                    throw new ArgumentException("Cannot sample a batch from an empty data set.", nameof(inputs_x));
                }

                var batch = new List<Tuple<float[], float>>(Math.Max(maxCount, 0));
                for (var n = 0; n < maxCount; n++)
                {
                    int index = s_random.Next(inputs_x.Count);
                    batch.Add(Tuple.Create(inputs_x[index], inputs_y[index]));
                }

                return batch;
            }

            /// <summary>
            /// Performs one gradient-descent update of <paramref name="weights"/> (in place) from a mini-batch.
            /// </summary>
            /// <remarks>
            /// For the per-sample loss 0.5 * (y - f(x))^2 the gradient with respect to a weight w is
            /// -(y - f(x)) * df/dw, so stepping against it means adding lr * (y - f(x)) * df/dw.
            /// Per-sample contributions are summed (not averaged) over the batch, which is equivalent
            /// to averaging with a learning rate of lr * batch size.
            /// </remarks>
            /// <param name="batch">Samples as (features, target); each feature array needs at least two elements.</param>
            /// <param name="weights">[b, theta1, theta2]; updated in place.</param>
            /// <param name="lr">Learning rate.</param>
            internal static void ApplyGradientStep(IEnumerable<Tuple<float[], float>> batch, float[] weights, float lr)
            {
                // Accumulate the whole batch before touching the weights so every sample is
                // evaluated against the same parameters.
                float[] step = new float[weights.Length];
                foreach (var sample in batch)
                {
                    float x1 = sample.Item1[0];
                    float x2 = sample.Item1[1];
                    float diff = sample.Item2 - fun(x1, x2, weights[1], weights[2], weights[0]);

                    step[0] += diff * dy_b(x1, x2);
                    step[1] += diff * dy_theta1(x1, x2);
                    step[2] += diff * dy_theta2(x1, x2);
                }

                for (var i = 0; i < weights.Length; i++)
                {
                    weights[i] += lr * step[i];
                }
            }

            /// <summary>
            /// Half mean squared error, (1/n) * sum(0.5 * (y - f(x))^2), of the given weights over the samples.
            /// </summary>
            /// <param name="xs">Feature vectors; each needs at least two elements (x1, x2).</param>
            /// <param name="ys">Targets, index-aligned with <paramref name="xs"/>.</param>
            /// <param name="weights">[b, theta1, theta2].</param>
            /// <returns>The half-MSE, or 0 when there are no samples.</returns>
            /// <exception cref="ArgumentNullException">Any argument is null.</exception>
            /// <exception cref="ArgumentException">The sample lists differ in length.</exception>
            internal static float ComputeCost(IList<float[]> xs, IList<float> ys, float[] weights)
            {
                if (xs == null)
                {
                    throw new ArgumentNullException(nameof(xs));
                }
                if (ys == null)
                {
                    throw new ArgumentNullException(nameof(ys));
                }
                if (weights == null)
                {
                    throw new ArgumentNullException(nameof(weights));
                }
                if (xs.Count != ys.Count)
                {
                    throw new ArgumentException("xs and ys must have the same number of elements.", nameof(ys));
                }
                if (xs.Count == 0)
                {
                    return 0f;
                }

                float total = 0f;
                for (var i = 0; i < xs.Count; i++)
                {
                    float diff = ys[i] - fun(xs[i][0], xs[i][1], weights[1], weights[2], weights[0]);
                    total += diff * diff / 2;
                }

                return total / xs.Count;
            }

            /// <summary>Writes the fitted model as "y=theta1*x1+theta2*x2+b" to the console.</summary>
            private static void print(float theta1, float theta2, float b)
            {
                Console.WriteLine("y={0}*x1+{1}*x2+{2}", theta1, theta2, b);
            }

            /// <summary>The model prediction theta1 * x1 + theta2 * x2 + b.</summary>
            private static float fun(float x1, float x2, float theta1, float theta2, float b)
            {
                return theta1 * x1 + theta2 * x2 + b;
            }

            /// <summary>Partial derivative of <see cref="fun"/> with respect to theta1, i.e. x1.</summary>
            private static float dy_theta1(float x1, float x2)
            {
                return x1;
            }

            /// <summary>Partial derivative of <see cref="fun"/> with respect to theta2, i.e. x2.</summary>
            private static float dy_theta2(float x1, float x2)
            {
                return x2;
            }

            /// <summary>Partial derivative of <see cref="fun"/> with respect to the bias b, i.e. 1.</summary>
            private static float dy_b(float x1, float x2)
            {
                return 1;
            }
        }
    }
    

      

  • 相关阅读:
    English,The Da Vinci Code, Chapter 23
    python,metaobject
    English,The Da Vinci Code, Chapter 22
    English,The Da Vinci Code, Chapter 21
    English,The Da Vinci Code, Chapter 20
    English,The Da Vinci Code, Chapter 19
    python,xml,ElementTree
    English,The Da Vinci Code, Chapter 18
    English,The Da Vinci Code, Chapter 17
    English,The Da Vinci Code, Chapter 16
  • 原文地址:https://www.cnblogs.com/aarond/p/7936523.html
Copyright © 2011-2022 走看看