zoukankan      html  css  js  c++  java
  • How to Initialize Neural Networks in PyTorch with Pretrained Nets in TensorFlow or Theano

    First, convert the network's weights and biases to NumPy arrays. Note that if you want to load a pretrained network with Keras, you must first define a Keras model with exactly the same structure.

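    Note which Keras backend you use. I installed Keras with the TensorFlow backend, but the VGG net I am going to replicate uses Theano (channels-first) dimension ordering, so every convolution, padding and pooling layer below passes dim_ordering="th".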

    from keras.models import Sequential
    from keras.layers.core import Flatten, Dense, Dropout
    from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
    from keras.optimizers import SGD
    import numpy as np
    
    def VGG_16(weights_path=None):
        """Build VGG-16 as a Keras ``Sequential`` model using Theano ("th") ordering.

        The network is five convolutional stages. Each conv layer is preceded by a
        1-pixel zero pad and uses a 3x3 kernel with ReLU activation. Each stage
        ends with a 2x2 / stride-2 max-pool. After the stages come two 4096-unit
        ReLU dense layers, each followed by 0.5 dropout, and a 1000-way softmax.
        Only the first four conv layers get explicit names ("conv1".."conv4") so
        they can later be fetched with ``model.get_layer``.

        Args:
            weights_path: optional path to a saved weight file. When truthy, the
                weights are loaded into the freshly built model.

        Returns:
            The assembled (uncompiled) ``Sequential`` model.
        """
        # (number of filters, number of conv layers) for each VGG-16 stage.
        stages = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))
        # Explicit names handed out, in order, to the leading conv layers only.
        explicit_names = iter(("conv1", "conv2", "conv3", "conv4"))
        # Only the very first layer of the model declares the input shape.
        first_layer_kwargs = {"input_shape": (3, 224, 224)}

        net = Sequential()
        for n_filters, n_convs in stages:
            for _ in range(n_convs):
                net.add(ZeroPadding2D((1, 1), dim_ordering="th", **first_layer_kwargs))
                first_layer_kwargs = {}

                conv_kwargs = {"activation": 'relu', "dim_ordering": "th"}
                layer_name = next(explicit_names, None)
                if layer_name is not None:
                    conv_kwargs["name"] = layer_name
                net.add(Convolution2D(n_filters, 3, 3, **conv_kwargs))
            net.add(MaxPooling2D((2, 2), strides=(2, 2), dim_ordering="th"))

        # Classifier head.
        net.add(Flatten())
        for hidden_units in (4096, 4096):
            net.add(Dense(hidden_units, activation='relu'))
            net.add(Dropout(0.5))
        net.add(Dense(1000, activation='softmax'))

        if weights_path:
            net.load_weights(weights_path)

        return net
    
    
    if __name__ == '__main__':
        # Build VGG-16 and restore the pretrained weights from disk.
        model = VGG_16(weights_path="vgg16_weights.h5")

        # Collect the (kernel, bias) pair of each of the four named conv layers.
        kernels = []
        biases = []
        for layer_name in ("conv1", "conv2", "conv3", "conv4"):
            kernel, bias = model.get_layer(name=layer_name).get_weights()
            kernels.append(kernel)
            biases.append(bias)

        # np.save appends the ".npy" extension to each of these file stems.
        for stem, array in zip(("weight1", "weight2", "weight3", "weight4"), kernels):
            np.save(stem, array)
        for stem, array in zip(("bias1", "bias2", "bias3", "bias4"), biases):
            np.save(stem, array)

        print(kernels[0].shape)
        print(biases[0].shape)
        # Round-trip sanity check: read one saved file back and show its shape.
        reloaded = np.load("bias1.npy")
        print(reloaded.shape)
    

    Then initialize the PyTorch network with the saved NumPy arrays.

    import torch.nn as nn
    import numpy as np
    import os

    import torch

    is_cuda = torch.cuda.is_available()

    # Set to True when the exported .npy kernels are in TensorFlow layout
    # (rows, cols, in_channels, out_channels) rather than PyTorch layout
    # (out_channels, in_channels, rows, cols).
    WEIGHTS_IN_TF_ORDER = False


    class vgglownet(nn.Module):
        """The first two convolutional stages of VGG-16 (conv1-conv4).

        Each 3x3 convolution (stride 1, padding 1) is followed by a ReLU. This
        matches the ``activation='relu'`` conv layers of the Keras model the
        weights are exported from. Each stage ends in a 2x2 max-pool.
        """

        def __init__(self):
            super(vgglownet, self).__init__()
            self.conv1 = nn.Conv2d(
                in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
            )
            self.conv2 = nn.Conv2d(
                in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
            )
            # MaxPool2d's stride defaults to kernel_size, i.e. 2x2 windows with stride 2.
            self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2))
            self.conv3 = nn.Conv2d(
                in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
            )
            self.conv4 = nn.Conv2d(
                in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
            )
            self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2))

        def forward(self, inputs):
            relu = nn.functional.relu
            x = relu(self.conv1(inputs))
            x = relu(self.conv2(x))
            x = self.maxpool1(x)
            x = relu(self.conv3(x))
            x = relu(self.conv4(x))
            x = self.maxpool2(x)
            return x


    def load_conv_params(conv, weight, bias, tf_order=False):
        """Replace ``conv``'s weight and bias with Parameters built from NumPy arrays.

        Args:
            conv: an ``nn.Conv2d`` whose ``weight``/``bias`` attributes are replaced.
            weight: kernel array in PyTorch layout (out, in, rows, cols), or in
                TensorFlow layout (rows, cols, in, out) when ``tf_order`` is True.
            bias: array whose shape must equal ``conv.bias``'s shape.
            tf_order: permute a TensorFlow-layout kernel into PyTorch layout first.

        Raises:
            ValueError: if an array's shape differs from the parameter it replaces.
                A missing or incorrect layout permutation is reported here instead
                of surfacing later inside ``forward``.
        """
        if tf_order:
            # (rows, cols, in, out) -> (out, in, rows, cols). Copy so the tensor is contiguous.
            weight = np.ascontiguousarray(weight.transpose(3, 2, 0, 1))
        for label, array, param in (("weight", weight, conv.weight), ("bias", bias, conv.bias)):
            expected = tuple(param.size())
            if tuple(array.shape) != expected:
                raise ValueError(
                    "%s shape %s does not match expected %s" % (label, tuple(array.shape), expected)
                )
        conv.weight = nn.Parameter(torch.from_numpy(weight))
        conv.bias = nn.Parameter(torch.from_numpy(bias))


    if __name__ == '__main__':
        net = vgglownet()
        for index in (1, 2, 3, 4):
            # NOTE(review): if these kernels originate from a Theano-trained model they
            # may also need a spatial flip ([:, :, ::-1, ::-1]). Theano's conv2d flips
            # filters by default, while PyTorch's Conv2d is a cross-correlation.
            # Confirm against a reference output.
            load_conv_params(
                getattr(net, "conv%d" % index),
                np.load("weight%d.npy" % index),
                np.load("bias%d.npy" % index),
                tf_order=WEIGHTS_IN_TF_ORDER,
            )

        os.makedirs("trained_nets", exist_ok=True)
        # torch.save(module) pickles the whole nn.Module, so loading the file later
        # requires the vgglownet class to be importable.
        torch.save(net, "trained_nets/vgglownetcpu.pt")
        if is_cuda:
            # Module.cuda() moves the parameters in place, so the CPU copy is saved first.
            torch.save(net.cuda(), "trained_nets/vgglownet.pt")

      

  • 相关阅读:
    组队开发最后冲刺周第一次会议
    android 本地数据库sqlite的封装
    java 空指针异常造成的原因有哪些
    jsp usebean的使用
    PHP模拟登录并获取数据
    php rsa加密解密实例
    30个php操作redis常用方法代码例子
    官方微信接口(全接口)
    curl类封装
    网站微信登录
  • 原文地址:https://www.cnblogs.com/cxxszz/p/8583184.html
Copyright © 2011-2022 走看看