  • Building a CNN (convolutional neural network) image-recognition model with PyTorch, used together with numpy and matplotlib, and trained on a CUDA GPU for acceleration

    1. Building a simple model with PyTorch

    # coding:utf-8
    import torch
    
    class Net(torch.nn.Module):
        def __init__(self,img_rgb=3,img_size=32,img_class=13):
            super(Net, self).__init__()
            self.conv1 = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=img_rgb, out_channels=img_size, kernel_size=3, stride=1, padding=1),  # img_size is used here as the number of conv1 output channels
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2),
                # torch.nn.Dropout(0.5)
            )
            self.conv2 = torch.nn.Sequential(
                torch.nn.Conv2d(img_size, 64, 3, 1, 1),  # in_channels must match conv1's out_channels (was hard-coded to 28)
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2)
            )
            self.conv3 = torch.nn.Sequential(
                torch.nn.Conv2d(64, 64, 3, 1, 1),
                torch.nn.ReLU(),
                torch.nn.MaxPool2d(2)
            )
            self.dense = torch.nn.Sequential(
                torch.nn.Linear(64 * 3 * 3, 128),  # for 28x28 input: 28 -> 14 -> 7 -> 3 after three 2x2 max-pools
                torch.nn.ReLU(),
                torch.nn.Linear(128, img_class)
            )
    
        def forward(self, x):
            conv1_out = self.conv1(x)
            conv2_out = self.conv2(conv1_out)
            conv3_out = self.conv3(conv2_out)
            res = conv3_out.view(conv3_out.size(0), -1)
            out = self.dense(res)
            return out
    
    CUDA = torch.cuda.is_available()
    
    model = Net(1, 28, 13)  # 1-channel 28x28 images, 13 classes
    print(model)
    
    optimizer = torch.optim.Adam(model.parameters())
    loss_func = torch.nn.MultiLabelSoftMarginLoss()  # expects one-hot targets; alternative: nn.CrossEntropyLoss() with integer class labels
    
    if CUDA:
        model.cuda()
    
    def batch_training_data(x_train, y_train, batch_size, i):
        # Return the i-th mini-batch; the final batch may be smaller than batch_size.
        n = len(x_train)
        left_limit = batch_size*i
        right_limit = left_limit+batch_size
        if n>=right_limit:
            return x_train[left_limit:right_limit,:,:,:],y_train[left_limit:right_limit,:]
        else:
            return x_train[left_limit:, :, :, :], y_train[left_limit:, :]
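
    As a quick sanity check (not part of the original post), a dummy forward pass confirms why the first dense layer expects 64 * 3 * 3 inputs: for a 1×28×28 image, each of the three 2×2 max-pools halves the spatial size (28 → 14 → 7 → 3), leaving 64 channels of 3×3 feature maps.

    # Sanity-check sketch: verify the flattened feature size assumed by the dense layers.
    import torch

    net = Net(img_rgb=1, img_size=28, img_class=13)
    dummy = torch.zeros(2, 1, 28, 28)                  # fake batch of two grayscale 28x28 images
    features = net.conv3(net.conv2(net.conv1(dummy)))
    print(features.shape)                              # torch.Size([2, 64, 3, 3]) -> 64 * 3 * 3 = 576
    print(net(dummy).shape)                            # torch.Size([2, 13]), one score per class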
    

      

    2. The training code

    #  coding:utf-8
    import time
    import os
    import torch
    import numpy as np
    from data_processing import get_DS
    from CNN_nework_model import cnn_face_discern_model
    from torch.autograd import Variable
    from use_torch_creation_model import optimizer, model, loss_func, batch_training_data,CUDA
    from sklearn.metrics import accuracy_score
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # TensorFlow log-level setting; has no effect in this PyTorch script
    
    st = time.time()
    # Load the data and split it into training and test sets at a ratio of 8:2
    x_,y_,y_true,label = get_DS()
    
    label_number = len(label)
    
    x_train,y_train = x_[:960,:,:,:].reshape((960,1,28,28)),y_[:960,:]
    
    x_test,y_test = x_[960:,:,:,:].reshape((340,1,28,28)),y_[960:,:]
    
    y_test_label = y_true[960:]
    
    print(time.time() - st)
    print(x_train.shape,x_test.shape)
    
    batch_size = 100
    n = (len(x_train) + batch_size - 1) // batch_size  # number of batches (ceiling division)
    
    
    for epoch in range(100):
        for batch in range(n):
            x_training,y_training = batch_training_data(x_train,y_train,batch_size,batch)
            batch_x,batch_y = Variable(torch.from_numpy(x_training)).float(),Variable(torch.from_numpy(y_training)).float()
            if CUDA:
                batch_x=batch_x.cuda()
                batch_y=batch_y.cuda()
    
            out = model(batch_x)
            loss = loss_func(out, batch_y)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Evaluate accuracy on the test set every few epochs
        if epoch % 9 == 0:
            x_test_tst = Variable(torch.from_numpy(x_test)).float()
            if CUDA:
                x_test_tst = x_test_tst.cuda()
            with torch.no_grad():  # no gradients needed for evaluation
                y_pred = model(x_test_tst)
    
            y_predict = np.argmax(y_pred.cpu().data.numpy(),axis=1)
    
            acc = accuracy_score(y_test_label,y_predict)
    
            print("loss={} aucc={}".format(loss.cpu().data.numpy(),acc))
    

      

    3. Summary

       The author trained the same model on the same image data with TensorFlow, Keras, and PyTorch, and found that PyTorch was many times faster, especially at loading the model, where it was much quicker than TensorFlow. That makes it a good fit for serving behind an API and integrating into projects; a minimal save/load sketch follows below.
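
    For deployment, a minimal sketch (the file name cnn_face.pth is illustrative, not from the original) of saving the trained weights and reloading them in an inference process:

    # Sketch: save the trained weights, then reload them for inference in another process.
    import torch

    torch.save(model.state_dict(), "cnn_face.pth")     # hypothetical file name

    # In the serving / inference code:
    infer_model = Net(1, 28, 13)
    infer_model.load_state_dict(torch.load("cnn_face.pth", map_location="cpu"))
    infer_model.eval()                                 # switch off training-only behaviour such as dropout
    with torch.no_grad():
        scores = infer_model(torch.zeros(1, 1, 28, 28))
        predicted_class = scores.argmax(dim=1).item()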

  • Original article: https://www.cnblogs.com/wuzaipei/p/11406450.html