zoukankan      html  css  js  c++  java
  • 最速下降法+Matlab代码

    算法原理

    to-do

    Matlab代码

    clc; clear;
    
    f = @(x) x(1).^2+2*x(1)*x(2)+3*x(2).^2; %待求函数,x1,x2,x3...
    % f = @(x) x(1).^2+2*x(2).^2;
    paraNum = 2; %函数参数的个数,x1,x2,x3...的个数
    x0 = [3,3]; %初始值
    tol = 1e-5; %迭代容忍度
    flag = inf; %结束条件
    error = []; %函数变化
    
    while flag > tol
        p = g(f,x0,paraNum); %列向量
        f2 = @(a) f(x0-a*p');
        buChang = argmin(f2); %求步长,line search:argmin function
        x1 = x0-buChang*p';
        flag = norm(x1-x0);
        error = [error,flag];
        x0 = x1;
    end
    plot(0:length(error)-1,error)
    
    function [f_grad] = g(f,x0,paraNum)
    temp = sym('x',[1,paraNum]);
    f1=f(temp);
    Z = gradient(f1);
    f_grad = double(subs(Z,temp,x0));
    end
    
    function [x] = argmin(f)
    %求步长
    t = 0;
    options = optimset('Display','off');
    [x,~] = fminunc(f,t,options);
    end
    

    代码问题

    1. Matlab符号运算,耗时
    2. 最速下降法的步长使用line-search,耗时

    代码改进

    clc; clear;
    
    f = @(x) x(1).^2+2*x(1)*x(2)+3*x(2).^2; %待求函数,x1,x2,x3...
    % f = @(x) x(1).^2+2*x(2).^2;
    paraNum = 2; %函数参数的个数,x1,x2,x3...的个数
    x0 = [3,3]; %初始值
    tol = 1e-3; %迭代容忍度
    flag = inf; %结束条件
    error = []; %函数变化
    
    while flag > tol
    % for i =1:1
        p = g(f,x0,paraNum); %列向量    
        if norm(p) < tol
                buChang = 0;
        else
            buChang = argmin(f,x0,p,paraNum); %求步长,line search:argmin function
        end
        x1 = x0-buChang.*p';
        flag = norm(x1-x0);
        error = [error,flag];
        x0 = x1;
    end
    plot(0:length(error)-1,error)
    
    function [f_grad] = g(f,x0,paraNum)
    temp = sym('x',[1,paraNum]);
    f1=f(temp);
    Z = gradient(f1);
    f_grad = double(subs(Z,temp,x0));
    end
    
    % function [x] = argmin(f,paraNum)
    % %求步长
    % t = zeros(1,paraNum);
    % options = optimset('Display','off');
    % [x,~] = fminunc(f,t,options);
    % end
    
    function [x] = argmin(f,x0,p,num)
    % 求步长
    % for i=1:paraNum
    %     syms(['x',num2str(i)]);
    % end
    temp = sym('x',[1,num]);
    f1=f(x0 - temp.*p');
    for i = 1:num
        temp(i) = diff(f1,temp(i));
    end
    jieGuo = solve(temp);
    jieGuo = struct2cell(jieGuo);
    x = zeros(1,num);
    for i = 1:num
        x(i) = double(jieGuo{i});
    end
    end
    
  • 相关阅读:
    sqlhelper使用指南
    大三学长带我学习JAVA。作业1. 第1讲.Java.SE入门、JDK的下载与安装、第一个Java程序、Java程序的编译与执行 大三学长带我学习JAVA。作业1.
    pku1201 Intervals
    hdu 1364 king
    pku 3268 Silver Cow Party
    pku 3169 Layout
    hdu 2680 Choose the best route
    hdu 2983
    pku 1716 Integer Intervals
    pku 2387 Til the Cows Come Home
  • 原文地址:https://www.cnblogs.com/kexve/p/11737898.html
Copyright © 2011-2022 走看看