zoukankan      html  css  js  c++  java
  • matlab编程代写使用贝叶斯优化的深度学习

    原文链接:http://tecdat.cn/?p=7954

    此示例说明如何将贝叶斯优化应用于深度学习,以及如何为卷积神经网络找到最佳网络超参数和训练选项。

    要训​​练深度神经网络,必须指定神经网络架构以及训练算法的选项。选择和调整这些超参数可能很困难并且需要时间。贝叶斯优化是一种非常适合用于优化分类和回归模型的超参数的算法。 

    准备数据

    下载CIFAR-10数据集[1]。该数据集包含60,000张图像,每个图像的大小为32 x 32和三个颜色通道(RGB)。整个数据集的大小为175 MB。 

    加载CIFAR-10数据集作为训练图像和标签,并测试图像和标签。 

    [XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);

    idx = randperm(numel(YTest),5000);
    XValidation = XTest(:,:,:,idx);
    XTest(:,:,:,idx) = [];
    YValidation = YTest(idx);
    YTest(idx) = [];

    您可以使用以下代码显示训练图像的样本。

    figure;
    idx = randperm(numel(YTrain),20);
    for i = 1:numel(idx)
        subplot(4,5,i);
        imshow(XTrain(:,:,:,idx(i)));
    end

    选择要优化的变量

    选择要使用贝叶斯优化进行优化的变量,并指定要搜索的范围。此外,指定变量是否为整数以及是否在对数空间中搜索区间。优化以下变量:

    • 网络部分的深度。此参数控制网络的深度。该网络具有三个部分,每个部分具有SectionDepth相同的卷积层。因此,卷积层的总数为3*SectionDepth。脚本后面的目标函数将每一层中的卷积过滤器数量与成正比1/sqrt(SectionDepth)。结果,对于不同的截面深度,每次迭代的参数数量和所需的计算量大致相同。

    •  最佳学习率取决于您的数据以及您正在训练的网络。

    • 随机梯度下降动量。 

    • L2正则化强度。 

    optimVars = [
    optimizableVariable('SectionDepth',[1 3],'Type','integer')
    optimizableVariable('InitialLearnRate',[1e-2 1],'Transform','log')
    optimizableVariable('Momentum',[0.8 0.98])
    optimizableVariable('L2Regularization',[1e-10 1e-2],'Transform','log')];

    执行贝叶斯优化

    使用训练和验证数据作为输入,为贝叶斯优化器创建目标函数。目标函数训练卷积神经网络,并在验证集上返回分类误差。 

    ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation);

    通过最小化验证集上的分类误差来执行贝叶斯优化。 为了充分利用贝叶斯优化的功能,您应该至少执行30个目标函数评估。 

    每个网络完成训练后,bayesopt将结果打印到命令窗口。bayesopt然后该函数返回中的文件名BayesObject.UserDataTrace。目标函数将训练有素的网络保存到磁盘,并将文件名返回给bayesopt


     
    |===================================================================================================================================|
    | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | SectionDepth | InitialLearn-| Momentum | L2Regulariza-|
    | | result | | runtime | (observed) | (estim.) | | Rate | | tion |
    |===================================================================================================================================|
    | 1 | Best | 0.19 | 2201 | 0.19 | 0.19 | 3 | 0.012114 | 0.8354 | 0.0010624 |
    |    2 | Accept |      0.3224 |      1734.1 |        0.19 |     0.19636 |            1 |     0.066481 |      0.88231 |    0.0026626 |
    |    3 | Accept |      0.2076 |      1688.7 |        0.19 |     0.19374 |            2 |     0.022346 |      0.91149 |    8.242e-10 |
    |    4 | Accept |      0.1908 |      2167.2 |        0.19 |      0.1904 |            3 |      0.97586 |      0.83613 |   4.5143e-08 |
    |    5 | Accept |      0.1972 |      2157.4 |        0.19 |     0.19274 |            3 |      0.21193 |      0.97995 |   1.4691e-05 |
    |    6 | Accept |      0.2594 |      2152.8 |        0.19 |        0.19 |            3 |      0.98723 |      0.97931 |   2.4847e-10 |
    |    7 | Best   |      0.1882 |      2257.5 |      0.1882 |     0.18819 |            3 |       0.1722 |       0.8019 |   4.2149e-06 |
    |    8 | Accept |      0.8116 |      1989.7 |      0.1882 |     0.18818 |            3 |      0.42085 |      0.95355 |    0.0092026 |
    |    9 | Accept |      0.1986 |        1836 |      0.1882 |     0.18821 |            2 |     0.030291 |      0.94711 |   2.5062e-05 |
    |   10 | Accept |      0.2146 |      1909.4 |      0.1882 |     0.18816 |            2 |     0.013379 |       0.8785 |   7.6354e-09 |
    |   11 | Accept |      0.2194 |        1562 |      0.1882 |     0.18815 |            1 |      0.14682 |      0.86272 |   8.6242e-09 |
    |   12 | Accept |      0.2246 |      1591.2 |      0.1882 |     0.18813 |            1 |      0.70438 |      0.82809 |   1.0102e-06 |
    |   13 | Accept |      0.2648 |      1621.8 |      0.1882 |     0.18824 |            1 |     0.010109 |      0.89989 |   1.0481e-10 |
    |   14 | Accept |      0.2222 |        1562 |      0.1882 |     0.18812 |            1 |      0.11058 |      0.97432 |   2.4101e-07 |
    |   15 | Accept |      0.2364 |      1625.7 |      0.1882 |     0.18813 |            1 |     0.079381 |       0.8292 |   2.6722e-05 |
    |   16 | Accept |        0.26 |      1706.2 |      0.1882 |     0.18815 |            1 |     0.010041 |      0.96229 |   1.1066e-05 |
    |   17 | Accept |      0.1986 |      2188.3 |      0.1882 |     0.18635 |            3 |      0.35949 |      0.97824 |    3.153e-07 |
    |   18 | Accept |      0.1938 |      2169.6 |      0.1882 |     0.18817 |            3 |     0.024365 |      0.88464 |   0.00024507 |
    |   19 | Accept |      0.3588 |      1713.7 |      0.1882 |     0.18216 |            1 |     0.010177 |      0.89427 |    0.0090342 |
    |   20 | Accept |      0.2224 |      1721.4 |      0.1882 |     0.18193 |            1 |      0.09804 |      0.97947 |   1.0727e-10 |
    |===================================================================================================================================|
    | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | SectionDepth | InitialLearn-| Momentum | L2Regulariza-|
    | | result | | runtime | (observed) | (estim.) | | Rate | | tion |
    |===================================================================================================================================|
    | 21 | Accept | 0.1904 | 2184.7 | 0.1882 | 0.18498 | 3 | 0.017697 | 0.95057 | 0.00022247 |
    |   22 | Accept |      0.1928 |      2184.4 |      0.1882 |     0.18527 |            3 |      0.06813 |       0.9027 |   1.3521e-09 |
    |   23 | Accept |      0.1934 |      2183.6 |      0.1882 |      0.1882 |            3 |     0.018269 |      0.90432 |    0.0003573 |
    |   24 | Accept |       0.303 |      1707.9 |      0.1882 |     0.18809 |            1 |     0.010157 |      0.88226 |   0.00088737 |
    |   25 | Accept |       0.194 |      2189.1 |      0.1882 |     0.18808 |            3 |     0.019354 |      0.94156 |   9.6197e-07 |
    |   26 | Accept |      0.2192 |      1752.2 |      0.1882 |     0.18809 |            1 |      0.99324 |      0.91165 |   1.1521e-08 |
    |   27 | Accept |      0.1918 |        2185 |      0.1882 |     0.18813 |            3 |      0.05292 |       0.8689 |   1.2449e-05 |
    __________________________________________________________
    Optimization completed.
    MaxTime of 50400 seconds reached.
    Total function evaluations: 27
    Total elapsed time: 51962.3666 seconds.
    Total objective function evaluation time: 51942.8833

    Best observed feasible point:
    SectionDepth InitialLearnRate Momentum L2Regularization
    ____________ ________________ ________ ________________

    3 0.1722 0.8019 4.2149e-06

    Observed objective function value = 0.1882
    Estimated objective function value = 0.18813
    Function evaluation time = 2257.4627

    Best estimated feasible point (according to models):
    SectionDepth InitialLearnRate Momentum L2Regularization
    ____________ ________________ ________ ________________

    3 0.1722 0.8019 4.2149e-06

    Estimated objective function value = 0.18813
    Estimated function evaluation time = 2166.2402

    评估最终网络

    加载优化中发现的最佳网络及其验证准确性。

    valError = 0.1882

    预测测试集的标签并计算测试误差。将测试集中每个图像的分类视为具有一定成功概率的独立事件,这意味着错误分类的图像数量遵循二项式分布。使用它来计算标准误差(testErrorSE)和testError95CI广义误差率的大约95%置信区间()。这种方法通常称为Wald方法。 


     
    testError = 0.1864

     
    testError95CI = 1×2

    0.1756 0.1972

    绘制混淆矩阵以获取测试数据。通过使用列和行摘要显示每个类的精度和召回率。


     

    您可以使用以下代码显示一些测试图像及其预测的类以及这些类的概率。

    优化目标函数

    定义用于优化的目标函数。 

    定义卷积神经网络架构。

    • 在卷积层上添加填充,以便空间输出大小始终与输入大小相同。

    • 每次使用最大池化层对空间维度进行2倍的下采样时,将过滤器的数量增加2倍。这样做可确保每个卷积层所需的计算量大致相同。

    • 选择与成正比的滤波器数量,以1/sqrt(SectionDepth)使不同深度的网络具有大致相同数量的参数,并且每次迭代所需的计算量大致相同。要增加网络参数的数量和整体网络灵活性,请增加numF。要训​​练更深的网络,请更改SectionDepth变量的范围。

    • 使用convBlock(filterSize,numFilters,numConvLayers)创建的块numConvLayers卷积层,每个具有指定filterSizenumFilters过滤器,并且每个随后分批正常化层和RELU层。该convBlock函数在本示例的末尾定义。

     指定验证数据,然后选择一个'ValidationFrequency'值,以便trainNetwork每个时期对网络进行一次验证。训练固定的时期数,并在最后一个时期将学习率降低10倍。这减少了参数更新的噪音,并使网络参数的沉降更接近损耗函数的最小值。

       

    使用数据增强可沿垂直轴随机翻转训练图像,并将它们随机水平和垂直转换为四个像素。

       

    训练网络并在训练过程中绘制训练进度。

          

    在验证集上评估经过训练的网络,计算预测的图像标签,并在验证数据上计算错误率。


     

    创建一个包含验证错误的文件名,然后将网络,验证错误和培训选项保存到磁盘。目标函数fileName作为输出参数bayesopt返回,并返回中的所有文件名BayesObject.UserDataTrace

    convBlock函数创建一个numConvLayers卷积层块,每个卷积层都有一个指定的filterSizenumFilters过滤器,每个卷积层后面都有一个批处理归一化层和一个ReLU层。


     

    参考文献

    [1]克里热夫斯基,亚历克斯。“从微小的图像中学习多层功能。” (2009)。https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

    如果您有任何疑问,请在下面发表评论。   

  • 相关阅读:
    【Azure Redis 缓存】Azure Redis 功能性讨论二
    【Azure Developer】如何用Microsoft Graph API管理AAD Application里面的Permissions
    【Azure 环境】通过Python SDK收集所有订阅简略信息,例如订阅id 名称, 资源组及组内资源信息等,如何给Python应用赋予相应的权限才能获取到信息呢?
    【Azure 应用服务】App Service与APIM同时集成到同一个虚拟网络后,如何通过内网访问内部VNET的APIM呢?
    【Azure 云服务】如何从Azure Cloud Service中获取项目的部署文件
    【Azure Redis 缓存】Azure Redis 异常
    【Azure 微服务】基于已经存在的虚拟网络(VNET)及子网创建新的Service Fabric并且为所有节点配置自定义DNS服务
    【Azure Redis 缓存】遇见Azure Redis不能创建成功的问题:至少一个资源部署操作失败,因为 Microsoft.Cache 资源提供程序未注册。
    【Azure Redis 缓存】如何得知Azure Redis服务有更新行为?
    【Azure API 管理】在 Azure API 管理中使用 OAuth 2.0 授权和 Azure AD 保护 Web API 后端,在请求中携带Token访问后报401的错误
  • 原文地址:https://www.cnblogs.com/tecdat/p/11720456.html
Copyright © 2011-2022 走看看