zoukankan      html  css  js  c++  java
  • 网络训练至某个epoch,参数 问题

    1 start_epoch = params.start_epoch
    2   stop_epoch = params.stop_epoch
    3   if params.resume != '':
    4     resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)  #get_resume_file函数得到epoch.tar文件
    5     if resume_file is not None:
    6       tmp = torch.load(resume_file)
    7       start_epoch = tmp['epoch']+1
    8       model.load_state_dict(tmp['state'])
    9       print('  resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))

    对tar文件进行加载,并且选取其中需要的字典权重

     1 tmp = torch.load(modelfile)   # load parameter file:400.tar
     2   try:
     3     state = tmp['state']
     4   except KeyError:
     5     state = tmp['model_state']
     6   except:
     7     raise
     8   state_keys = list(state.keys())  #列举字典中的key
     9   for i, key in enumerate(state_keys):
    10     if "feature." in key and not 'gamma' in key and not 'beta' in key:
    11       newkey = key.replace("feature.","")
    12       state[newkey] = state.pop(key)  #删除该key并返回对应的值,不影响上面的训练
    13     else:
    14       state.pop(key)
    15 
    16   model.load_state_dict(state) 
  • 相关阅读:
    字段与表的对应关系
    java初学代码,还不太熟练
    编程学习心得
    ps中经常遇到的问题
    R语言矩阵运算加速
    写代码过程中一些数字推理公式
    EXCEL中常用的函数
    css样式中常见的属性
    R语言的基本矩阵运算
    excel常用的函数
  • 原文地址:https://www.cnblogs.com/stepping/p/13403741.html
Copyright © 2011-2022 走看看