  • PyTorch: Manually Extracting Layers

    Model

    NeuralNet(
      (l0): Linear(in_features=6, out_features=256, bias=True)
      (relu): ReLU()
      (bn0): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (l00): Linear(in_features=256, out_features=1, bias=True)
    )
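The printout lists submodules in the order they were registered in `__init__`, which is not necessarily the order in which `forward()` applies them. The tensors extracted below follow PyTorch's shape conventions: `nn.Linear.weight` is `(out_features, in_features)`, and BatchNorm1d's `weight`, `bias`, `running_mean` and `running_var` are 1-D of length `num_features`. The following is a small dependency-free sketch of that shape and parameter-count arithmetic for this architecture. The helper names (`linear_shapes`, `batchnorm1d_shapes`, `numel`) are hypothetical.

```python
# Hypothetical helpers: expected tensor shapes for the NeuralNet printed above,
# following PyTorch conventions (Linear weight is (out_features, in_features)).

def linear_shapes(in_features, out_features):
    return {"weight": (out_features, in_features), "bias": (out_features,)}

def batchnorm1d_shapes(num_features):
    return {
        "weight": (num_features,),        # gamma (trainable, affine=True)
        "bias": (num_features,),          # beta (trainable, affine=True)
        "running_mean": (num_features,),  # buffer, not a parameter
        "running_var": (num_features,),   # buffer, not a parameter
    }

def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n

# ReLU has no parameters, so it is left out.
layers = {
    "l0": linear_shapes(6, 256),
    "bn0": batchnorm1d_shapes(256),
    "l00": linear_shapes(256, 1),
}

# Only weight/bias tensors are trainable parameters.
TRAINABLE = {"weight", "bias"}
trainable_params = sum(
    numel(shape)
    for tensors in layers.values()
    for name, shape in tensors.items()
    if name in TRAINABLE
)

for layer, tensors in layers.items():
    for name, shape in tensors.items():
        print(f"{layer}.{name}: {shape}")
print("trainable parameters:", trainable_params)
```

For this architecture the count is 6·256 + 256 + 256 + 256 + 256·1 + 1 = 2561, which `sum(p.numel() for p in model.parameters())` should also report. `running_mean` and `running_var` are buffers, so they appear in `model.state_dict()` but not in `model.parameters()`.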
    

    Extract layers

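    import numpy as np

    model.eval()  # inference mode: BatchNorm uses running stats, Dropout is disabled

    # ReLU activation (element-wise, works on whole arrays)
    ReLU = lambda x: np.maximum(0.0, x)
    # GPU torch.Tensor to CPU numpy ndarray, shape (N, 6)
    X_data = X_valid.cpu().numpy()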
    
    # Fully-connected layer
    W0 = model.l0.weight.cpu().detach().numpy() # Weights W
    b0 = model.l0.bias.cpu().detach().numpy() # Bias b
    
    # Batch Normalization Layer
    bn_mean = model.bn0.running_mean.cpu().numpy()
    bn_mean = np.reshape(bn_mean, (256, -1))
    bn_var = model.bn0.running_var.cpu().numpy()
    bn_var = np.reshape(bn_var, (256, -1))
    bn_gamma = model.bn0.weight.cpu().detach().numpy()
    bn_gamma = np.reshape(bn_gamma, (256, -1))
    bn_beta = model.bn0.bias.cpu().detach().numpy()
    bn_beta = np.reshape(bn_beta, (256, -1))
    bn_epsilon = model.bn0.eps
    
    # Final output layer
    W00 = model.l00.weight.cpu().detach().numpy()
    b00 = model.l00.bias.cpu().detach().numpy()
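At inference time, the BatchNorm quantities extracted above combine per feature $j$ as shown below, with $\hat{\mu}$ = `bn_mean` (`running_mean`), $\hat{\sigma}^2$ = `bn_var` (`running_var`), $\gamma$ = `bn_gamma`, $\beta$ = `bn_beta` and $\epsilon$ = `bn_epsilon`. Note that $\epsilon$ sits inside the square root. Reshaping each vector to `(256, 1)` lets the same formula broadcast over all $N$ columns of a `(256, N)` activation matrix.

```latex
y_j = \gamma_j \, \frac{x_j - \hat{\mu}_j}{\sqrt{\hat{\sigma}_j^{2} + \epsilon}} + \beta_j ,
\qquad j = 1, \dots, 256
```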
    

    Feed-forward calculation

    # First layer: Linear (W0 @ X^T + b0), then ReLU; out has shape (256, N)
    out = np.dot(W0, X_data.T) + b0.reshape(-1, 1)
    out = ReLU(out)

    # BatchNorm layer (inference mode): eps goes inside the square root
    out = (out - bn_mean) / np.sqrt(bn_var + bn_epsilon) * bn_gamma + bn_beta

    # Final output layer; out has shape (1, N)
    out = np.dot(W00, out) + b00.reshape(-1, 1)
    out = ReLU(out)  # final ReLU, assuming the model's forward() also applies one here
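To check a manual feed-forward sample by sample, a dependency-free reference of the same Linear, ReLU, BatchNorm (eval), Linear, ReLU computation can help. This is a minimal sketch: the parameter-dict keys and the toy values (2 inputs and 2 hidden units instead of 6 and 256) are hypothetical, and the final ReLU is included only on the assumption that the model's `forward()` applies one.

```python
import math

def linear(W, b, x):
    """y = W x + b for one sample; W is a list of rows (out_features x in_features)."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i for row, b_i in zip(W, b)]

def relu(v):
    return [max(0.0, v_i) for v_i in v]

def batchnorm_eval(v, mean, var, gamma, beta, eps):
    """Inference-mode BatchNorm1d: uses running statistics, eps inside the sqrt."""
    return [
        g * (v_i - m) / math.sqrt(s2 + eps) + bt
        for v_i, m, s2, g, bt in zip(v, mean, var, gamma, beta)
    ]

def forward(x, p):
    h = relu(linear(p["W0"], p["b0"], x))
    h = batchnorm_eval(h, p["bn_mean"], p["bn_var"], p["bn_gamma"], p["bn_beta"], p["bn_eps"])
    return relu(linear(p["W00"], p["b00"], h))  # assumed final ReLU

# Hypothetical toy parameters, small enough to check by hand.
params = {
    "W0": [[1.0, -1.0], [0.5, 0.5]],
    "b0": [0.0, 1.0],
    "bn_mean": [1.0, 2.0],
    "bn_var": [4.0, 1.0],
    "bn_gamma": [2.0, 1.0],
    "bn_beta": [0.5, 0.0],
    "bn_eps": 1e-5,
    "W00": [[1.0, 1.0]],
    "b00": [0.25],
}
x = [3.0, 1.0]
print(forward(x, params))
```

By hand: the first layer gives [2, 3] (unchanged by ReLU), BatchNorm gives [1.5, 1.0] when ε is ignored, and the output layer gives 1.5 + 1.0 + 0.25 = 2.75. With `bn_eps = 1e-5` the printed value is slightly below 2.75; with `bn_eps = 0.0` it is exactly `[2.75]`.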
    
  • Original post: https://www.cnblogs.com/yaos/p/10525032.html