zoukankan      html  css  js  c++  java
  • PyTorch怎么抽取中间的特征或者梯度

    for i, (input, target) in enumerate(trainloader):
        # Training-loop fragment showing how to inspect an intermediate
        # feature map (module forward hook) and the gradient of a tensor
        # (tensor hook) for one chosen iteration (i == 2).

        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.cuda(), target.cuda()
        if i == 2:
            # Forward hook: PyTorch calls it with (module, input, output)
            # once the hooked module has computed its output.
            def for_hook(module, input, output):
                print('output values:', output)

            # NOTE(review): `model.module` presumably means the model is
            # wrapped (e.g. nn.DataParallel) -- confirm against the caller.
            handle2 = model.module.conv1.register_forward_hook(for_hook)

        # compute output
        output = model(input)
        # output = output1*2

        if i == 2:
            # Tensor hook: called with the gradient w.r.t. `output` when
            # that gradient is computed during loss.backward().
            def variable_hook(grad):
                print('grad:', grad)

            hook_handle = output.register_hook(variable_hook)

        # output2 = 2*output_
        # output = 0.5*output2
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec = accuracy(output, target)[0]
        losses.update(loss.item(), input.size(0))
        top1.update(prec.item(), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        # print('output.grad:',output1.grad)

        # print('input.grad:',input.grad)
        # print('input.is_leaf:',input.is_leaf)
        # output.register_hook(print)
        # zz.backward()

        # print("output.grad:",output.grad)
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i == 2:
            # print('the input is :', input)
            print('the output is :', output)
            # Detach both hooks so later iterations run without printing.
            hook_handle.remove()
            handle2.remove()
            # print('the target is :', target)
            # print('parameters:',optimizer.param_groups)

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}] '
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                  'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                      epoch, i, len(trainloader), batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))

     

    解释:

    # Define the forward hook. PyTorch calls it with (module, input, output)
    # each time the hooked module finishes its forward computation.
    def for_hook(module, input, output):
        """Print the output produced by the hooked layer."""
        print('output values:', output)

    # Pick the layer whose output should be extracted (here
    # model.module.avgpool) and attach the hook to it.
    handle2 = model.module.avgpool.register_forward_hook(for_hook)
    try:
        # Run the forward pass; this is what triggers the hook.
        output = model(input)
    finally:
        # Remove the hook even if the forward pass raises, so it does not
        # stay attached to the layer and fire on later forward passes.
        handle2.remove()

  • 相关阅读:
    文件系统操作与磁盘管理
    文件打包与解压缩
    环境变量与文件查找
    Linux目录结构及文件基本操作
    vim3
    vim2
    vim1
    用户管理
    初识
    第一章
  • 原文地址:https://www.cnblogs.com/Wanggcong/p/10269823.html
Copyright © 2011-2022 走看看