zoukankan      html  css  js  c++  java
  • kaggle gradient_descent

    kaggle gradient_descent

    1.描述

    自写梯度下降：用 NumPy 手动实现批量梯度下降，拟合带截距项的线性回归。

    2.代码

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Small hand-made dataset, kept for reference: both feature columns are
    # 1..10 and y = 2 * x, so a good fit predicts y = x1 + x2
    # (the test rows below should give 5 10 15 20).
    # train_X = np.array([[1,2,3,4,5,6,7,8,9,10],[1,2,3,4,5,6,7,8,9,10]]).T
    # train_y = np.array([2,4,6,8,10,12,14,16,18,20]).T
    # test_X = np.array([[2,4,12,11],[3,6,3,9]]).T # 5 10 15 20

    # Random demo data from the global NumPy RNG: 1000 samples x 10 features,
    # with targets drawn independently of the features. The draws are made
    # in the order train_X, train_y, test_X.
    train_X, train_y, test_X = (
        np.random.randn(1000, 10),
        np.random.randn(1000, 1),
        np.random.randn(1000, 10),
    )

    # Gradient-descent settings: learning rate, iteration cap, and the
    # cost-decrease threshold below which the run is treated as converged.
    step_len, max_iterations, epsilon = 0.1, 100000, 1e-7
    
    def ComputeCost(X,y,theta):
        """Return the least-squares cost J = sum((X . theta - y)**2) / (2*m).

        Parameters
        ----------
        X : numpy.ndarray
            Design matrix of shape (m, n), including any bias column.
        y : array_like
            Targets with m entries; reshaped to a column of shape (m, 1).
        theta : numpy.ndarray
            Parameters as a column of shape (n, 1). A 1-D theta would make
            X.dot(theta) shape (m,), which broadcasts wrongly against (m, 1).

        Returns
        -------
        float
            The scalar cost.
        """
        y = np.asarray(y)
        m = len(y)
        residual = X.dot(theta) - y.reshape(m, 1)
        # np.sum reduces every element in native code and yields a scalar.
        # The builtin sum() iterates rows in Python and returns a 1-element array.
        return float(np.sum(residual * residual)) / (2 * m)
    
    def GradientDescent(X,y,step_len,max_iterations):
        """Fit linear-regression parameters with batch gradient descent.

        A column of ones is prepended to X, so theta[0] is the intercept.
        Each iteration applies

            theta <- theta - step_len / m * X^T (X theta - y)

        and appends ComputeCost for the new theta to the history. From the
        second iteration on, the run stops early in two cases:
          * the cost fell by at least 0 and at most the module-level
            ``epsilon``: prints '已收敛' ("converged");
          * the cost rose: prints '步长过大' ("step too large"). The returned
            theta is then the one produced by that cost-increasing step.

        Returns
        -------
        (theta, J_his)
            theta has shape (n_features + 1, 1). J_his is a list with the cost
            recorded after every update.
        """
        targets = np.array(y)
        rows = len(targets)
        design = np.column_stack((np.ones((rows, 1)), np.array(X)))
        # The target column does not change, so reshape it once up front.
        target_col = targets.reshape(targets.shape[0], 1)
        theta = np.zeros((design.shape[1], 1))
        J_his = []
        for _ in range(max_iterations):
            gradient = design.T.dot(design.dot(theta) - target_col)
            theta = theta - step_len / rows * gradient
            J_his.append(ComputeCost(design, targets, theta))
            if len(J_his) < 2:
                continue
            # Positive drop means the cost went down on this step.
            drop = J_his[-2] - J_his[-1]
            if 0 <= drop <= epsilon:
                print('已收敛')
                break
            if drop <= 0:
                print('步长过大')
                break
        return theta, J_his
    
    def Predict(X,theta):
        """Return predictions for the rows of X using parameters theta.

        A leading column of ones is added to X, so theta[0] acts as the
        intercept. This matches the layout that GradientDescent produces.
        """
        # column_stack turns the 1-D ones vector into an (n, 1) column.
        augmented = np.column_stack([np.ones(X.shape[0]), X])
        return augmented.dot(theta)
    
    def Normalizetion(x):
        """Mean-normalize each column of x as (x - column_mean) / (column_max - column_min).

        Parameters
        ----------
        x : array_like
            2-D data; statistics are taken along axis 0, i.e. per column.

        Returns
        -------
        numpy.ndarray
            Array with the same shape as x. Each column is centred on 0 and
            scaled by its range. A column whose max equals its min (a
            constant feature) is only centred, so it becomes all zeros
            instead of 0/0 = NaN.
        """
        x = np.asarray(x)
        mean = np.mean(x, axis=0)
        spread = np.max(x, axis=0) - np.min(x, axis=0)
        # A NaN column would make every later gradient-descent update NaN,
        # so divide constant columns by 1 instead of 0.
        spread = np.where(spread == 0, 1, spread)
        return (x - mean) / spread
    #############################################################################
    # Script: normalize the training features, fit with gradient descent,
    # then plot the cost history.

    train_X = Normalizetion(train_X)
    theta,J_his = GradientDescent(train_X,train_y,step_len,max_iterations)
    # print('theta =', theta, '\n')

    # NOTE(review): test_X is never normalized. Before using Predict, scale it
    # with the training set's column mean and range; otherwise theta is applied
    # to features on a different scale.
    # print(Predict(test_X,theta))

    # x-axis: iteration index; y-axis: cost recorded after that iteration.
    train_time = range(len(J_his))
    plt.plot(train_time, J_his)
    plt.xlabel('train_time')
    plt.ylabel('cost_fun_J')
    plt.show()
    
  • 相关阅读:
    DAY1 linux 50条命令
    安卓2.0,3.0,4.0的差别
    java历史
    晶体管共射极单管放大电路
    jquery取消选择select下拉框
    Oracle数据库导入导出,创建表空间
    360chrome,google chrome浏览器使用jquery.ajax加载本地html文件
    jquery 选择器
    nodejs 相关
    关于http请求
  • 原文地址:https://www.cnblogs.com/cbattle/p/8810701.html
Copyright © 2011-2022 走看看