zoukankan      html  css  js  c++  java
  • 最速下降法+Matlab代码

    算法原理

    to-do

    Matlab代码

    clc; clear;
    
    f = @(x) x(1).^2+2*x(1)*x(2)+3*x(2).^2; %待求函数,x1,x2,x3...
    % f = @(x) x(1).^2+2*x(2).^2;
    paraNum = 2; %函数参数的个数,x1,x2,x3...的个数
    x0 = [3,3]; %初始值
    tol = 1e-5; %迭代容忍度
    flag = inf; %结束条件
    error = []; %函数变化
    
    while flag > tol
        p = g(f,x0,paraNum); %列向量
        f2 = @(a) f(x0-a*p');
        buChang = argmin(f2); %求步长,line search:argmin function
        x1 = x0-buChang*p';
        flag = norm(x1-x0);
        error = [error,flag];
        x0 = x1;
    end
    plot(0:length(error)-1,error)
    
    function [f_grad] = g(f,x0,paraNum)
    temp = sym('x',[1,paraNum]);
    f1=f(temp);
    Z = gradient(f1);
    f_grad = double(subs(Z,temp,x0));
    end
    
    function [x] = argmin(f)
    %求步长
    t = 0;
    options = optimset('Display','off');
    [x,~] = fminunc(f,t,options);
    end
    

    代码问题

    1. Matlab符号运算,耗时
    2. 最速下降法的步长使用line-search,耗时

    代码改进

    clc; clear;
    
    f = @(x) x(1).^2+2*x(1)*x(2)+3*x(2).^2; %待求函数,x1,x2,x3...
    % f = @(x) x(1).^2+2*x(2).^2;
    paraNum = 2; %函数参数的个数,x1,x2,x3...的个数
    x0 = [3,3]; %初始值
    tol = 1e-3; %迭代容忍度
    flag = inf; %结束条件
    error = []; %函数变化
    
    while flag > tol
    % for i =1:1
        p = g(f,x0,paraNum); %列向量    
        if norm(p) < tol
                buChang = 0;
        else
            buChang = argmin(f,x0,p,paraNum); %求步长,line search:argmin function
        end
        x1 = x0-buChang.*p';
        flag = norm(x1-x0);
        error = [error,flag];
        x0 = x1;
    end
    plot(0:length(error)-1,error)
    
    function [f_grad] = g(f,x0,paraNum)
    temp = sym('x',[1,paraNum]);
    f1=f(temp);
    Z = gradient(f1);
    f_grad = double(subs(Z,temp,x0));
    end
    
    % function [x] = argmin(f,paraNum)
    % %求步长
    % t = zeros(1,paraNum);
    % options = optimset('Display','off');
    % [x,~] = fminunc(f,t,options);
    % end
    
    function [x] = argmin(f,x0,p,num)
    % 求步长
    % for i=1:paraNum
    %     syms(['x',num2str(i)]);
    % end
    temp = sym('x',[1,num]);
    f1=f(x0 - temp.*p');
    for i = 1:num
        temp(i) = diff(f1,temp(i));
    end
    jieGuo = solve(temp);
    jieGuo = struct2cell(jieGuo);
    x = zeros(1,num);
    for i = 1:num
        x(i) = double(jieGuo{i});
    end
    end
    
  • 相关阅读:
    数组的操作方法
    数组遍历的方法以及区别
    组件内的守卫
    路由守卫
    软件工程
    java web (j2ee)学习路线 —— 将青春交给命运
    团队作业(一)- 第五组
    软件工程
    软件工程-第二次作业
    java局部/成员/静态/实例变量
  • 原文地址:https://www.cnblogs.com/kexve/p/11737898.html
Copyright © 2011-2022 走看看