zoukankan      html  css  js  c++  java
  • 梯度下降代码

    https://blog.csdn.net/panghaomingme/article/details/79384922

     1 #!usr/bin/python3
     2 # coding:utf-8
     3  
     4 # BGD 批梯度下降代码实现
     5 # SGD 随机梯度下降代码实现
     6 import numpy as np
     7  
     8 import random
     9  
    10  
    11 def batchGradientDescent(x, y, theta, alpha, m, maxInteration):
    12     x_train = x.transpose()
    13     for i in range(0, maxInteration):
    14         hypothesis = np.dot(x, theta)
    15         # 损失函数
    16         loss = hypothesis - y
    17         # 下降梯度
    18         gradient = np.dot(x_train, loss) / m
    19         # 求导之后得到theta
    20         theta = theta - alpha * gradient
    21     return theta
    22  
    23  
    24 def stochasticGradientDescent(x, y, theta, alpha, m, maxInteration):
    25     data = []
    26     for i in range(4):
    27         data.append(i)
    28     x_train = x.transpose()
    29     for i in range(0, maxInteration):
    30         hypothesis = np.dot(x, theta)
    31         # 损失函数
    32         loss = hypothesis - y
    33         # 选取一个随机数
    34         index = random.sample(data, 1)
    35         index1 = index[0]
    36         # 下降梯度
    37         gradient = loss[index1] * x[index1]
    38         # 求导之后得到theta
    39         theta = theta - alpha * gradient
    40     return theta
    41  
    42  
    43 def main():
    44     trainData = np.array([[1, 4, 2], [2, 5, 3], [5, 1, 6], [4, 2, 8]])
    45     trainLabel = np.array([19, 26, 19, 20])
    46     print(trainData)
    47     print(trainLabel)
    48     m, n = np.shape(trainData)
    49     theta = np.ones(n)
    50     print(theta.shape)
    51     maxInteration = 500
    52     alpha = 0.01
    53     theta1 = batchGradientDescent(trainData, trainLabel, theta, alpha, m, maxInteration)
    54     print(theta1)
    55     theta2 = stochasticGradientDescent(trainData, trainLabel, theta, alpha, m, maxInteration)
    56     print(theta2)
    57     return
    58  
    59  
    60 if __name__ == "__main__":
    61     main()
  • 相关阅读:
    重大技术需求系统八
    2020年下半年软考真题及答案解析
    周总结五
    重大技术需求系统七
    TextWatcher 编辑框监听器
    Android四大基本组件介绍与生命周期
    JAVA String,StringBuffer与StringBuilder的区别??
    iOS开发:保持程序在后台长时间运行
    宏定义的布局约束
    随便说一些
  • 原文地址:https://www.cnblogs.com/zhangbojiangfeng/p/9474298.html
Copyright © 2011-2022 走看看