zoukankan      html  css  js  c++  java
  • How to Initialize Neural Networks in PyTorch with Pretrained Nets in TensorFlow or Theano

    First, convert the network's weights and biases to NumPy arrays. Note that to load a pre-trained network with Keras, you must first define a Keras model with exactly the same network structure as the one the weights were saved from.

    Pay attention to which Keras backend you are using. I installed Keras with the TensorFlow backend, but the VGG net I am going to replicate uses Theano dimension ordering, so every layer below is declared with dim_ordering="th".

    from keras.models import Sequential
    from keras.layers.core import Flatten, Dense, Dropout
    from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
    from keras.optimizers import SGD
    import numpy as np
    
    def VGG_16(weights_path=None):
        """Build the VGG-16 network as a Keras ``Sequential`` model.

        The model consists of five convolutional stages (each convolution is
        preceded by a 1-pixel zero padding and each stage ends with a 2x2
        max-pool), followed by two 4096-unit dense layers with dropout and a
        1000-way softmax. Every spatial layer is declared with
        ``dim_ordering="th"`` and the input shape is ``(3, 224, 224)``.

        Only the first four convolutions get explicit names (``conv1`` ..
        ``conv4``) so they can be looked up later with ``model.get_layer``.

        Args:
            weights_path: Optional path to a weights file; when truthy it is
                passed to ``model.load_weights``.

        Returns:
            The constructed (and optionally weight-loaded) model.
        """
        # (number of filters, explicit name per convolution or None) per stage.
        stages = (
            (64, ("conv1", "conv2")),
            (128, ("conv3", "conv4")),
            (256, (None, None, None)),
            (512, (None, None, None)),
            (512, (None, None, None)),
        )

        model = Sequential()
        first_layer = True
        for filters, conv_names in stages:
            for conv_name in conv_names:
                pad_kwargs = {"dim_ordering": "th"}
                if first_layer:
                    # Only the very first layer declares the input shape.
                    pad_kwargs["input_shape"] = (3, 224, 224)
                    first_layer = False
                model.add(ZeroPadding2D((1, 1), **pad_kwargs))

                conv_kwargs = {"activation": 'relu', "dim_ordering": "th"}
                if conv_name is not None:
                    conv_kwargs["name"] = conv_name
                model.add(Convolution2D(filters, 3, 3, **conv_kwargs))
            model.add(MaxPooling2D((2, 2), strides=(2, 2), dim_ordering="th"))

        # Classifier head.
        model.add(Flatten())
        for _ in range(2):
            model.add(Dense(4096, activation='relu'))
            model.add(Dropout(0.5))
        model.add(Dense(1000, activation='softmax'))

        if weights_path:
            model.load_weights(weights_path)

        return model
    
    
    if __name__ == '__main__':
        # Load the pre-trained model and dump the parameters of the four named
        # convolution layers to .npy files in the current directory.
        weights_path = "vgg16_weights.h5"
        model = VGG_16(weights_path=weights_path)

        layer_ids = (1, 2, 3, 4)
        kernels = {}
        biases = {}
        for layer_id in layer_ids:
            layer = model.get_layer(name="conv%d" % layer_id)
            kernel, bias = layer.get_weights()
            kernels[layer_id] = kernel
            biases[layer_id] = bias

        # np.save appends the .npy extension to these base names.
        for layer_id in layer_ids:
            np.save("weight%d" % layer_id, kernels[layer_id])
        for layer_id in layer_ids:
            np.save("bias%d" % layer_id, biases[layer_id])

        # Sanity check: show the shapes and confirm a saved file reads back.
        print(kernels[1].shape)
        print(biases[1].shape)
        reloaded_bias = np.load("bias1.npy")
        print(reloaded_bias.shape)
    

    Then initialize the PyTorch network with the saved NumPy arrays.

    import os

    import numpy as np
    import torch
    import torch.nn as nn
    # True when a CUDA device is usable; gates the GPU export at the bottom.
    is_cuda = torch.cuda.is_available()


    class vgglownet(nn.Module):
        """The first two convolutional stages of VGG-16 (conv1-conv4, two max-pools).

        Layer hyper-parameters mirror the Keras definition: 3x3 kernels,
        stride 1, 1-pixel padding, and 2x2 max-pooling after every second
        convolution.
        """

        def __init__(self):
            super(vgglownet, self).__init__()
            self.conv1 = nn.Conv2d(
                in_channels=3, out_channels=64,
                kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
            )
            self.conv2 = nn.Conv2d(
                in_channels=64, out_channels=64,
                kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
            )
            # stride is omitted: nn.MaxPool2d defaults it to kernel_size.
            self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))
            self.conv3 = nn.Conv2d(
                in_channels=64, out_channels=128,
                kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
            )
            self.conv4 = nn.Conv2d(
                in_channels=128, out_channels=128,
                kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
            )
            self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))

        def forward(self, inputs):
            """Run ``inputs`` through conv1, conv2, pool, conv3, conv4, pool."""
            # NOTE(review): the Keras layers these weights come from use
            # activation='relu', but no ReLU is applied here, so the outputs
            # will not match the Keras model. Left unchanged to keep existing
            # behaviour -- confirm whether the activations should be added.
            x = self.conv1(inputs)
            x = self.conv2(x)
            x = self.maxpool1(x)
            x = self.conv3(x)
            x = self.conv4(x)
            x = self.maxpool2(x)
            return x


    def load_conv_params(conv, weight, bias):
        """Install NumPy ``weight`` and ``bias`` arrays as the parameters of ``conv``.

        Args:
            conv: An ``nn.Conv2d`` whose parameters are replaced.
            weight: Kernel array already laid out as (out, in, rows, cols).
                Kernels stored as (rows, cols, in, out) must first be
                converted with ``weight.transpose(3, 2, 0, 1)``.
            bias: Bias array of shape (out,).

        Raises:
            ValueError: If either array's shape differs from the shape of the
                parameter it replaces.
        """
        # NOTE(review): kernels trained with a Theano backend may additionally
        # need a spatial flip (weight[:, :, ::-1, ::-1]) -- TODO confirm
        # against the source of the weights file.
        expected_weight = tuple(conv.weight.shape)
        expected_bias = tuple(conv.bias.shape)
        if tuple(weight.shape) != expected_weight:
            raise ValueError(
                "weight has shape %s but the layer expects %s; kernels stored as "
                "(rows, cols, in, out) need weight.transpose(3, 2, 0, 1)"
                % (tuple(weight.shape), expected_weight)
            )
        if tuple(bias.shape) != expected_bias:
            raise ValueError(
                "bias has shape %s but the layer expects %s"
                % (tuple(bias.shape), expected_bias)
            )
        conv.weight = nn.Parameter(torch.from_numpy(weight))
        conv.bias = nn.Parameter(torch.from_numpy(bias))


    if __name__ == '__main__':
        net = vgglownet()
        # Load the arrays exported by the Keras script into conv1..conv4.
        for layer_id in range(1, 5):
            weight = np.load("weight%d.npy" % layer_id)
            bias = np.load("bias%d.npy" % layer_id)
            load_conv_params(getattr(net, "conv%d" % layer_id), weight, bias)

        # torch.save does not create missing directories.
        if not os.path.isdir("trained_nets"):
            os.makedirs("trained_nets")

        torch.save(net, "trained_nets/vgglownetcpu.pt")
        if is_cuda:
            torch.save(net.cuda(), "trained_nets/vgglownet.pt")
        else:
            # Moving the net to the GPU would raise on a CPU-only machine.
            print("CUDA is not available; skipping trained_nets/vgglownet.pt")

      

  • 相关阅读:
    Numpy存字符串
    一个类似于postman的协议测试工具
    freetds设置超时
    学习jQuery
    webpy 使用python3开发
    gdb调试coredump文件
    htop和ncdu
    rqalpha-自动量化交易系统(一)
    perl学习-运算符添加引号
    xss 和 csrf攻击详解
  • 原文地址:https://www.cnblogs.com/cxxszz/p/8583184.html
Copyright © 2011-2022 走看看