zoukankan      html  css  js  c++  java
  • SGD

    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Text;
    using System.Threading.Tasks;
    
    namespace ConsoleApp4
    {
        /// <summary>
        /// Console demo that fits the linear model y = theta1*x1 + theta2*x2 + b to a small
        /// hard-coded data set using mini-batch stochastic gradient descent (SGD) on the
        /// half squared error 0.5 * (y - f(x))^2.
        /// </summary>
        class Program
        {
            // Single generator shared by weight initialisation and batch sampling.
            // Creating a fresh clock-seeded Random on every call makes calls that land in the
            // same clock tick produce the same sequence: the initial weights could all be equal
            // and consecutive "random" batches would repeat. This demo is single-threaded, so
            // sharing one (non-thread-safe) Random instance is fine.
            private static readonly Random rnd = new Random();

            /// <summary>
            /// Runs SGD until the full-data cost changes by at most <c>epsilon</c> between two
            /// consecutive epochs (or the epoch limit is hit), prints the fitted model and waits
            /// for Enter.
            /// </summary>
            static void Main(string[] args)
            {
                List<float[]> inputs_x = new List<float[]>();
                inputs_x.Add(new float[] { 0.9f, 0.6f });
                inputs_x.Add(new float[] { 2f, 2.5f });
                inputs_x.Add(new float[] { 2.6f, 2.3f });
                inputs_x.Add(new float[] { 2.7f, 1.9f });

                List<float> inputs_y = new List<float>();
                inputs_y.Add(2.5f);
                inputs_y.Add(2.5f);
                inputs_y.Add(3.5f);
                inputs_y.Add(4.2f);

                // Layout: weights[0] = b (bias), weights[1] = theta1, weights[2] = theta2.
                float[] weights = new float[3];
                for (var i = 0; i < weights.Length; i++)
                {
                    weights[i] = (float)rnd.NextDouble();
                }

                int epoch = 30000;
                float epsilon = 0.00001f;
                float lr = 0.01f;

                float lastCost = 0;

                // Runs epochs 0..epoch inclusive so the final epoch's cost is always printed.
                for (var epoch_i = 0; epoch_i <= epoch; epoch_i++)
                {
                    // Draw a random mini-batch (sampled with replacement).
                    var batch = GetRandomBatch(inputs_x, inputs_y, 2);

                    // Accumulate the negative gradient of sum(0.5 * (y - f(x))^2) over the batch:
                    // -dL/dw = (y - f(x)) * df/dw, so adding lr * this moves downhill.
                    float[] weights_in_poch = new float[weights.Length];
                    foreach (var x_y in batch)
                    {
                        var x1 = x_y.Item1[0];
                        var x2 = x_y.Item1[1];
                        var target_y = x_y.Item2;

                        float diffWithTargetY = target_y - fun(x1, x2, weights[1], weights[2], weights[0]);

                        weights_in_poch[0] += diffWithTargetY * dy_b(x1, x2);
                        weights_in_poch[1] += diffWithTargetY * dy_theta1(x1, x2);
                        weights_in_poch[2] += diffWithTargetY * dy_theta2(x1, x2);
                    }

                    for (var i = 0; i < weights.Length; i++)
                    {
                        weights[i] += lr * weights_in_poch[i];
                    }

                    // Measure convergence on the whole data set. The mini-batch changes every
                    // epoch, so comparing the costs of two different random batches reflects
                    // sampling noise rather than convergence and can stop the loop spuriously.
                    // Note: the value printed as "MSE" is the mean of 0.5 * squared error.
                    float cost = ComputeCost(inputs_x, inputs_y, weights);

                    if (Math.Abs(cost - lastCost) <= epsilon)
                    {
                        Console.WriteLine(string.Format("EPOCH {0}", epoch_i));
                        Console.WriteLine(string.Format("LAST MSE {0}", lastCost));
                        Console.WriteLine(string.Format("MSE {0}", cost));
                        break;
                    }

                    lastCost = cost;

                    if (epoch_i % 100 == 0 || epoch_i == epoch)
                    {
                        Console.WriteLine(string.Format("MSE {0}", cost));
                    }
                }

                print(weights[1], weights[2], weights[0]);

                Console.ReadLine();
            }

            /// <summary>
            /// Draws <paramref name="maxCount"/> (x, y) pairs uniformly at random, with
            /// replacement, from the paired lists <paramref name="inputs_x"/> and
            /// <paramref name="inputs_y"/>. A non-positive <paramref name="maxCount"/> yields
            /// an empty list.
            /// </summary>
            /// <exception cref="ArgumentNullException">Either list is null.</exception>
            /// <exception cref="ArgumentException">
            /// The lists differ in length, or they are empty while samples are requested.
            /// </exception>
            private static List<Tuple<float[], float>> GetRandomBatch(List<float[]> inputs_x, List<float> inputs_y, int maxCount)
            {
                if (inputs_x == null)
                {
                    throw new ArgumentNullException("inputs_x");
                }
                if (inputs_y == null)
                {
                    throw new ArgumentNullException("inputs_y");
                }
                if (inputs_x.Count != inputs_y.Count)
                {
                    throw new ArgumentException("inputs_x and inputs_y must have the same number of elements.");
                }
                if (maxCount > 0 && inputs_x.Count == 0)
                {
                    throw new ArgumentException("Cannot sample a batch from empty inputs.");
                }

                List<Tuple<float[], float>> lst = new List<Tuple<float[], float>>(Math.Max(maxCount, 0));
                for (int count = 0; count < maxCount; count++)
                {
                    int rndIndex = rnd.Next(inputs_x.Count);
                    lst.Add(Tuple.Create(inputs_x[rndIndex], inputs_y[rndIndex]));
                }

                return lst;
            }

            /// <summary>
            /// Mean, over every sample, of the half squared error 0.5 * (y - f(x))^2 for the
            /// given weights (weights[0] = b, weights[1] = theta1, weights[2] = theta2).
            /// Callers must pass non-empty, equally sized lists.
            /// </summary>
            private static float ComputeCost(List<float[]> inputs_x, List<float> inputs_y, float[] weights)
            {
                float total = 0f;
                for (var i = 0; i < inputs_x.Count; i++)
                {
                    float[] x = inputs_x[i];
                    float diff = inputs_y[i] - fun(x[0], x[1], weights[1], weights[2], weights[0]);
                    total += diff * diff / 2;
                }
                return total / inputs_x.Count;
            }

            /// <summary>Prints the fitted model as "y=theta1*x1+theta2*x2+b".</summary>
            private static void print(float theta1, float theta2, float b)
            {
                Console.WriteLine(string.Format("y={0}*x1+{1}*x2+{2}", theta1, theta2, b));
            }

            /// <summary>The model: returns theta1*x1 + theta2*x2 + b.</summary>
            private static float fun(float x1, float x2, float theta1, float theta2, float b)
            {
                return theta1 * x1 + theta2 * x2 + b;
            }

            /// <summary>Partial derivative of <see cref="fun"/> with respect to theta1 (= x1).</summary>
            private static float dy_theta1(float x1, float x2)
            {
                return x1;
            }

            /// <summary>Partial derivative of <see cref="fun"/> with respect to theta2 (= x2).</summary>
            private static float dy_theta2(float x1, float x2)
            {
                return x2;
            }

            /// <summary>Partial derivative of <see cref="fun"/> with respect to the bias b (= 1).</summary>
            private static float dy_b(float x1, float x2)
            {
                return 1;
            }
        }
    }
    

      

  • 相关阅读:
    为什么不要在Spring的配置里,配置上XSD的版本号
    使用GIT管理UE4代码
    C++ 编程错误记录
    Maven 命令及其他备忘
    Windows API 之 CreateToolhelp32Snapshot
    Windows API 之 ReadProcessMemory
    Windows API 之 OpenProcessToken、GetTokenInformation
    利用未文档化API:RtlAdjustPrivilege 提权实现自动关机
    WindowsAPI 之 CreatePipe、CreateProcess
    错误: error C4996: 'strcpy': This function or variable may be unsafe. Consider using strcpy_s instead. 的处理方法
  • 原文地址:https://www.cnblogs.com/aarond/p/7936523.html
Copyright © 2011-2022 走看看