zoukankan      html  css  js  c++  java
  • c++实现unet

    #include<torch/torch.h>
    #include<cassert>
    #include<fstream>
    #include<iostream>
    #include<memory>
    #include<stdlib.h>
    #include<string>
    #include<unordered_map>
    #include<vector>
    // Basic UNet building block: (3x3 conv, pad 1) -> BatchNorm -> ReLU,
    // applied twice.  Padding 1 with kernel 3 preserves spatial size, so only
    // the channel count changes (in_ch -> out_ch -> out_ch).
    class double_conv:public torch::nn::Module
    {
        public:
            torch::nn::Conv2d conv1,conv2;
            torch::nn::BatchNorm bn1,bn2;
            int in_ch,out_ch;
        public:
            double_conv(int in_ch,int out_ch)
                : in_ch(in_ch),
                  out_ch(out_ch),
                  conv1(torch::nn::Conv2dOptions(in_ch,out_ch,3).padding(1)),
                  bn1(out_ch),
                  conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),
                  bn2(out_ch)
            {
                // Register so parameters() / serialization can see the layers.
                register_module("conv1",conv1);
                register_module("conv2",conv2);
                register_module("bn1",bn1);
                register_module("bn2",bn2);
            }
            // Runs both conv -> BN -> ReLU stages in sequence.
            torch::Tensor forward(torch::Tensor x)
            {
                torch::Tensor h = torch::relu(bn1->forward(conv1->forward(x)));
                return torch::relu(bn2->forward(conv2->forward(h)));
            }
    };
    class inconv:public torch::nn::Module
    {
        public:
            int in_ch,out_ch;
        public:
            inconv(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){}
            torch::Tensor forward(torch::Tensor x)
            {
                 double_conv dc(in_ch,out_ch);
                 x = dc.forward(x);
                 return x;
            }
    };
    class down:public torch::nn::Module
    {
        public:
            int in_ch,out_ch;
        public:
            down(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch){}
            torch::Tensor forward(torch::Tensor x)
            {
                x = torch::max_pool2d(x,2);
                double_conv dc(in_ch,out_ch);
                x = dc.forward(x);
                return x;
            }
    };
    class up:public torch::nn::Module
    {
        public:
            int in_ch,out_ch;
            torch::nn::Conv2d upconv;
            torch::nn::Conv2d conv1,conv2;
            torch::nn::BatchNorm bn1,bn2;
            torch::Tensor x;
        public:
            up(int in_ch,int out_ch):in_ch(in_ch),out_ch(out_ch),upconv(torch::nn::Conv2dOptions(in_ch,out_ch,4).padding(1).stride(2).transposed(new bool(true))),
                                     conv1(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn1(out_ch),conv2(torch::nn::Conv2dOptions(out_ch,out_ch,3).padding(1)),bn2(out_ch)
            {
                register_module("upconv",upconv);
                register_module("conv1",conv2);
                register_module("conv2",conv2);
                register_module("bn1",bn1);
                register_module("bn2",bn2);
            }
            torch::Tensor forward(torch::Tensor x1,torch::Tensor x2)
            {
                x = upconv->forward(x1);
                x = torch::cat({x,x2},1);
                double_conv dc(x.size(1),out_ch);
                x = dc.forward(x);
                //x = conv1->forward(x);
                //x = bn1->forward(x);
                //x = torch::relu(x);
                //x = conv2->forward(x);
                //x = bn2->forward(x);
                //x = torch::relu(x);
                return x;
            }
    };
    // Output head: a 1x1 convolution mapping feature channels to per-pixel
    // class scores; spatial size is unchanged.
    class outconv:public torch::nn::Module
    {
        public:
            int in_ch,out_ch;
            torch::nn::Conv2d conv;
        public:
            outconv(int in_ch,int out_ch)
                : in_ch(in_ch),
                  out_ch(out_ch),
                  conv(torch::nn::Conv2dOptions(in_ch,out_ch,1).padding(0))
            {
                register_module("conv",conv);
            }
            torch::Tensor forward(torch::Tensor x)
            {
                torch::Tensor scores = conv->forward(x);
                return scores;
            }
    };
    class unet:public torch::nn::Module
    {
        public:
            int n_ch,n_class;
            inconv *iconv= new inconv(n_ch,64);
            down *down1= new down(64,256);
            down *down2= new down(256,512);
            down *down3= new down(512,512);
            down *down4= new down(512,512);
            up *up1= new up(512,256);
            up *up2= new up(256,128);
            up *up3= new up(128,64);
            up *up4= new up(64,64);
            outconv *oconv= new outconv(64,n_class);
            torch::Tensor x1,x2,x3,x4,x5;
        public:
            unet(int n_ch,int n_class):n_ch(n_ch),n_class(n_class){}
            torch::Tensor forward(torch::Tensor x)
            {
               x1 = iconv->forward(x);
               x2 = down1->forward(x1);
               x3 = down2->forward(x2);
               x4 = down3->forward(x3);
               x5 = down4->forward(x4);
               x = up1->forward(x5,x4);
               x = up2->forward(x,x3);
               x = up3->forward(x,x2);
               x = up4->forward(x,x1);
               x = oconv->forward(x);
               return x;
            }
    };
    // Splits `str` on any character in `delimiters` and converts every token
    // to float via atof (tokens that do not parse as numbers become 0.0f).
    // Runs of consecutive delimiters are treated as a single separator.
    std::vector<float> Tokenize(const std::string& str,const std::string& delimiters)
    {
        std::vector<float> tokens;
        std::string::size_type start = str.find_first_not_of(delimiters, 0);
        while (start != std::string::npos)
        {
            const std::string::size_type stop = str.find_first_of(delimiters, start);
            tokens.push_back(static_cast<float>(std::atof(str.substr(start, stop - start).c_str())));
            start = str.find_first_not_of(delimiters, stop);
        }
        return tokens;
    }
    // Reads a space-separated numeric text file into one vector<float> per
    // line.  Aborts via assert if the file cannot be opened (NOTE: the check
    // disappears under NDEBUG builds).
    //
    // FIXES vs the original: removed a no-op static_cast<std::string> of a
    // string and the redundant `vec` temporary; corrected the "gdood" typo in
    // the progress message.
    std::vector<std::vector<float>> readTxt(std::string file)
    {
        std::ifstream infile;
        infile.open(file.data());
        assert(infile.is_open());
        std::string s;
        std::vector<std::vector<float>> res;
        while(getline(infile,s))
        {
            // Empty lines still produce an (empty) row, as before.
            res.push_back(Tokenize(s, " "));
        }
        infile.close();
        std::cout<<"good"<<std::endl;
        return res;
    }
    // Loads the ground-truth label map from a hard-coded text file and wraps
    // it in a 2478x3125 float tensor.
    torch::Tensor float2TensorLabel()
    {
        // Static buffer (~31 MB) backs the returned tensor: tensorFromBlob
        // wraps existing memory, so the storage must outlive the tensor --
        // `static` guarantees that, but every call reuses the same buffer.
        static float tt[2478][3125]={0};
        //memset(tt,0,sizeof(tt));
        // NOTE(review): path and the 2478x3125 dimensions are hard-coded;
        // presumably one image row per text line -- confirm against the data.
        std::vector<std::vector<float>> vec = readTxt("/Users/yanlang/unet/mx-unet/U-Net/LabelData.txt");
        int ch = vec.size();
        int len = vec[0].size();
        // Copy the parsed rows into the flat static buffer.
        for(int i=0;i<ch;i++)
        {
            for(int j=0;j<len;j++)
            {
                tt[i][j]=vec[i][j];
            }
        }
        // Legacy (pre-1.0) libtorch API; the modern spelling is torch::from_blob.
        torch::Tensor tmask = torch::CPU(torch::kFloat).tensorFromBlob(tt,{2478,3125});
        return tmask;
    }
    // Loads the 7-channel input image from a hard-coded text file and wraps
    // it in a {7, 2478, 3125} float tensor.
    torch::Tensor float2TensorData()
    {
        // Static buffer (~217 MB!) backing the returned tensor; tensorFromBlob
        // does not copy, so the storage must outlive the tensor.  Every call
        // reuses (and overwrites) this same buffer.
        static float tt[7][2478*3125] = {0};
        // NOTE(review): expects one channel per text line, each 2478*3125
        // values long -- confirm against the data file's layout.
        std::vector<std::vector<float>> vec = readTxt("/Users/yanlang/unet/mx-unet/U-Net/ImageData.txt");
        int ch = vec.size();
        int len = vec[0].size();
        // Copy the parsed rows (channels) into the flat static buffer.
        for(int i=0;i<ch;i++)
        {
            for(int j=0;j<len;j++)
            {
                tt[i][j]=vec[i][j];
            }
        }
        // Legacy (pre-1.0) libtorch API; the modern spelling is torch::from_blob.
        torch::Tensor tdata = torch::CPU(torch::kFloat).tensorFromBlob(tt,{7,2478,3125});
        return tdata;
    }
    int imgH=256;  // height (px) of random training crops and inference windows
    int imgW=256;  // width  (px) of random training crops and inference windows
    // Crops a {7, imgH, imgW} patch out of `data` starting at row `hight`,
    // column `width`.  Caller guarantees the window lies inside the image.
    //
    // PERF FIX: the original copied element-by-element through Tensor
    // operator[], which materializes a temporary tensor per scalar -- O(7*
    // imgH*imgW) tensor allocations.  narrow(...).clone() performs the same
    // copy in a single vectorized call with identical values and shape.
    torch::Tensor RandData(torch::Tensor data,int hight,int width)
    {
        return data.narrow(1,hight,imgH).narrow(2,width,imgW).clone();
    }
    // Crops an {imgH, imgW} patch out of the 2-D label map starting at row
    // `hight`, column `width`.  Caller guarantees the window is in bounds.
    //
    // PERF FIX: replaces the original per-element Tensor-indexing copy (one
    // temporary tensor per scalar) with a single narrow(...).clone(), which
    // produces an identical owning copy of the same region.
    torch::Tensor RandMask(torch::Tensor label, int hight,int width)
    {
        return label.narrow(0,hight,imgH).narrow(1,width,imgW).clone();
    }
    // Builds one random-crop minibatch: returns {data[batch,7,imgH,imgW],
    // label[batch,imgH,imgW]}, where each sample uses the same random offset
    // for the image and its mask so they stay aligned.
    std::vector<torch::Tensor> DataLoader(torch::Tensor data,torch::Tensor label,int batch_size)
    {
        const int fullH = data.size(1);
        const int fullW = data.size(2);
        torch::Tensor batch_data = torch::zeros({batch_size,7,imgH,imgW});
        torch::Tensor batch_label = torch::zeros({batch_size,imgH,imgW});
        for(int b=0;b<batch_size;b++)
        {
            const int offH = rand()%(fullH-imgH-1);
            const int offW = rand()%(fullW-imgW-1);
            batch_data[b] = RandData(data,offH,offW);
            batch_label[b] = RandMask(label,offH,offW);
        }
        return {batch_data,batch_label};
    }
    // Extracts a single {1, 7, imgH, imgW} inference crop from `data`.
    //
    // GENERALIZED: the crop origin, previously hard-coded as 500 with a
    // literal 756 end bound (500 + 256), is now a pair of parameters that
    // default to the old values, and the crop extent follows the imgH/imgW
    // globals.  Existing call sites (Get_predData(data)) are unaffected.
    // PERF FIX: per-element Tensor indexing replaced by narrow(...).clone().
    torch::autograd::Variable Get_predData(torch::autograd::Variable data,int hight=500,int width=500)
    {
        torch::autograd::Variable tmp = data.narrow(1,hight,imgH).narrow(2,width,imgW).clone();
        return torch::unsqueeze(tmp,0);  // add the batch dimension
    }
    // Streams every element of the 2-D tensor `data` to "tresult.txt", one
    // element per line, rows in order.
    void write2Txt(torch::autograd::Variable data)
    {
        std::ofstream fout("tresult.txt");
        const int rows = data.size(0);
        const int cols = data.size(1);
        for(int r=0;r<rows;r++)
        {
            for(int c=0;c<cols;c++)
            {
                fout<<data[r][c]<<std::endl;
            }
        }
        fout.close();
    }
    // Writes each parameter to "unet.txt" as its name followed by its values;
    // `weights` and `key` are parallel arrays (same length, same order).
    void saveModel(std::vector<torch::Tensor> weights,std::vector<std::string> key)
    {
        std::ofstream fout("unet.txt");
        for(std::size_t idx=0;idx<weights.size();++idx)
        {
            fout<<key[idx]<<std::endl;
            fout<<weights[idx]<<std::endl;
        }
        fout.close();
    }
    void trainConvNet(unet model)
    {
        torch::optim::SGD optimizer(model.parameters(),/*lr=*/0.01);
        torch::Tensor pred;
        std::cout<<"load data ......"<<std::endl;
        torch::autograd::Variable data = torch::autograd::make_variable(float2TensorData());
        torch::autograd::Variable label = torch::autograd::make_variable(float2TensorLabel());
        std::cout<<"done!!"<<std::endl;
        torch::Tensor train_data,train_label;
        std::vector<torch::Tensor> vecdata;
        for(int epoch=0;epoch<20;epoch++)
        {
            vecdata = DataLoader(data,label,2);
            std::cout<<"vecdata after done!!"<<std::endl;
            train_data = vecdata[0];
            std::cout<<"train_data after done"<<std::endl;
            train_label = vecdata[1];
            std::cout<<train_label.size(0)<<std::endl;
            std::cout<<"train_label after done"<<std::endl;
            pred = model.forward(train_data);
            auto loss = torch::nll_loss2d(pred,torch::_cast_Long(train_label));//torch::_cast_Long()
            std::cout<<"the loss is"<<loss<<std::endl;
            optimizer.zero_grad();
            loss.backward();
            optimizer.step();
        }
        std::vector<torch::Tensor> vecValue;
        std::vector<std::string> vecKey;
        torch::nn::ParameterCursor tt = model.parameters();
        for(auto it=tt.begin();it!=tt.end();it++)
        {
            vecValue.push_back((*it).value);
            vecKey.push_back((*it).key);
        }
        saveModel(vecValue,vecKey);
        torch::autograd::Variable predData = Get_predData(data);
        torch::autograd::Variable fl = model.forward(predData);
        torch::autograd::Variable result = torch::squeeze(fl);
        torch::autograd::Variable rt = result.argmax(0);
        std::cout<<rt.size(0)<<std::endl;
        std::cout<<rt.size(1)<<std::endl;
        write2Txt(rt);
    }
    // Entry point: build a UNet taking 7 input channels and producing 2
    // classes, then run the training/inference pipeline on it.
    int main()
    {
        const int input_channels = 7;
        const int num_classes = 2;
        unet net(input_channels,num_classes);
        trainConvNet(net);
        return 0;
    }
  • 相关阅读:
    C#中静态与非静态方法比较
    Hibernate学习之路-- -映射 继承关系(subclass , joined-subclass,union-subclass )
    网络协议概述:物理层、连接层、网络层、传输层、应用层详解
    phpstorm xdebug配置
    eclipse修改内存大小
    Session机制详解
    java把html标签字符转普通字符(反转换成html标签)(摘抄)
    JAVA调用WCF
    RabbitMQ入门与使用篇
    大话程序猿眼里的高并发
  • 原文地址:https://www.cnblogs.com/semen/p/9778300.html
Copyright © 2011-2022 走看看