  • PyTorch: Manually Extracting Layers

    Model

    NeuralNet(
      (l0): Linear(in_features=6, out_features=256, bias=True)
      (relu): ReLU()
      (bn0): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (l00): Linear(in_features=256, out_features=1, bias=True)
    )
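The printout gives layer names and sizes but not the order in which `forward` applies them. Here is a minimal sketch of a class that would print the same module listing. The forward order Linear -> ReLU -> BatchNorm -> Linear -> ReLU is an assumption, chosen to match the manual feed-forward calculation further down. Iterating over `state_dict()` then lists every tensor that can be extracted, including the BatchNorm running statistics, which are buffers rather than parameters.

```python
import torch
import torch.nn as nn

class NeuralNet(nn.Module):
    # Hypothetical reconstruction: names and sizes come from the printout,
    # the forward order is assumed to match the manual NumPy code.
    def __init__(self, n_in=6, n_hidden=256):
        super().__init__()
        self.l0 = nn.Linear(n_in, n_hidden)
        self.relu = nn.ReLU()
        self.bn0 = nn.BatchNorm1d(n_hidden)
        self.l00 = nn.Linear(n_hidden, 1)

    def forward(self, x):
        x = self.bn0(self.relu(self.l0(x)))
        return self.relu(self.l00(x))

model = NeuralNet()
print(model)

# Every extractable tensor (parameters and buffers) with its shape
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
```

`bn0.num_batches_tracked` also shows up in the listing. It is not needed for the manual forward pass.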
    

    Extract layers

    import numpy as np
    
    model.eval() # Inference mode: BatchNorm uses its running statistics and Dropout is disabled
    
    # ReLU activation
    ReLU = lambda x: np.maximum(0.0, x)
    # GPU torch.Tensor to CPU numpy ndarray
    X_data = X_valid.detach().cpu().numpy()
    
    # Fully-connected layer
    W0 = model.l0.weight.cpu().detach().numpy() # Weights W
    b0 = model.l0.bias.cpu().detach().numpy() # Bias b
    
    # Batch Normalization layer: running statistics (buffers) and learned gamma/beta, reshaped to (256, 1) columns
    bn_mean = model.bn0.running_mean.cpu().numpy()
    bn_mean = np.reshape(bn_mean, (256, -1))
    bn_var = model.bn0.running_var.cpu().numpy()
    bn_var = np.reshape(bn_var, (256, -1))
    bn_gamma = model.bn0.weight.cpu().detach().numpy()
    bn_gamma = np.reshape(bn_gamma, (256, -1))
    bn_beta = model.bn0.bias.cpu().detach().numpy()
    bn_beta = np.reshape(bn_beta, (256, -1))
    bn_epsilon = model.bn0.eps
    
    # Final output layer
    W00 = model.l00.weight.cpu().detach().numpy()
    b00 = model.l00.bias.cpu().detach().numpy()
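The BatchNorm tensors are reshaped to `(256, 1)` because the feed-forward code keeps activations in a `(features, samples)` layout, so each statistic has to broadcast across columns. The sketch below uses random stand-ins for the extracted arrays, not real trained values. It shows that eval-mode BatchNorm is just a per-feature affine map, which can be folded into a single scale and shift. It also shows that a flat `(256,)` vector would fail to broadcast.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_samples = 256, 4

# Hypothetical stand-ins for the extracted BatchNorm arrays, each shaped (256, 1)
bn_mean = rng.normal(size=(n_features, 1))
bn_var = rng.uniform(0.5, 2.0, size=(n_features, 1))
bn_gamma = rng.normal(size=(n_features, 1))
bn_beta = rng.normal(size=(n_features, 1))
bn_epsilon = 1e-5

# Activations in the (features, samples) layout used by the feed-forward code
out = rng.normal(size=(n_features, n_samples))

# Eval-mode BatchNorm: normalize with running statistics, then scale and shift
bn_out = (out - bn_mean) / np.sqrt(bn_var + bn_epsilon) * bn_gamma + bn_beta

# The same map folded into one per-feature scale and shift
scale = bn_gamma / np.sqrt(bn_var + bn_epsilon)
shift = bn_beta - bn_mean * scale
folded = scale * out + shift

print(bn_out.shape)                 # (256, 4)
print(np.allclose(bn_out, folded))  # True

# Without the reshape, (256, 4) and (256,) do not broadcast (trailing dims 4 vs 256)
try:
    out - bn_mean.ravel()
except ValueError as err:
    print("broadcast error:", err)
```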
    

    Feed-forward calculation

    # First layer: W0 @ X^T + b0, result has shape (256, n_samples)
    out = np.dot(W0, np.transpose(X_data)) + np.tile(np.reshape(b0, (-1, 1)), X_data.shape[0])
    out = ReLU(out)
    
    # BatchNorm layer (eval mode): normalize with running statistics, then scale and shift
    out = (out - bn_mean) / np.sqrt(bn_var + bn_epsilon) * bn_gamma + bn_beta
    
    # Final output, shape (1, n_samples)
    out = np.dot(W00, out) + np.tile(np.reshape(b00, (-1, 1)), X_data.shape[0])
    out = ReLU(out)
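A quick way to confirm that the manual calculation reproduces PyTorch is to compare both on the same input. The sketch below does not use the trained model. It builds hypothetical, randomly initialized stand-ins for `l0`, `bn0` and `l00`, and it runs a few training-mode batches so that the BatchNorm running statistics are not trivial. It also assumes the forward order Linear -> ReLU -> BatchNorm -> Linear -> ReLU. If the order or `bn_epsilon` is wrong, the comparison fails.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for model.l0, model.bn0 and model.l00
l0 = nn.Linear(6, 256)
bn0 = nn.BatchNorm1d(256)
l00 = nn.Linear(256, 1)

# Populate the BatchNorm running statistics with a few training-mode batches
bn0.train()
with torch.no_grad():
    for _ in range(5):
        bn0(torch.relu(l0(torch.randn(32, 6))))
bn0.eval()

X_valid = torch.randn(10, 6)

# Reference output from PyTorch (assumed forward order), shape (10, 1)
with torch.no_grad():
    expected = torch.relu(l00(bn0(torch.relu(l0(X_valid))))).numpy()

# Manual NumPy pass in the (features, samples) layout used above
to_np = lambda t: t.detach().cpu().numpy()
X_data = to_np(X_valid)
out = to_np(l0.weight) @ X_data.T + to_np(l0.bias)[:, None]
out = np.maximum(0.0, out)
mean = to_np(bn0.running_mean)[:, None]
var = to_np(bn0.running_var)[:, None]
gamma = to_np(bn0.weight)[:, None]
beta = to_np(bn0.bias)[:, None]
out = (out - mean) / np.sqrt(var + bn0.eps) * gamma + beta
out = to_np(l00.weight) @ out + to_np(l00.bias)[:, None]
out = np.maximum(0.0, out)  # shape (1, 10)

print("max abs error:", np.abs(out.T - expected).max())
```

Because both sides compute in float32, expect differences on the order of 1e-6 rather than exact equality.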
    
  • Original post: https://www.cnblogs.com/yaos/p/14014232.html