  • Training an FCN and predicting on the TGS dataset

    I. Background

    Kaggle hosts a semantic segmentation competition on identifying salt deposits: TGS Salt Identification Challenge | Kaggle  https://www.kaggle.com/c/tgs-salt-identification-challenge


    II. Procedure

    1. Download the data: https://www.kaggle.com/c/tgs-salt-identification-challenge/data

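    Data description:

    train.csv
    id rle_mask
    4000 entries, i.e. 4000 training images
    
    depths.csv
    id  z
    z is the depth underground, in feet
    22000 entries (the total number of train and test images)
    z ranges over [50, 959]
    
    test
    18000 images
    
    
    sample_submission.csv
    5f3b26ac68,1 2626 2628 100
    The rle_mask is a list of (start, run length) pairs. Pixel indices start at 1, so row and column numbers must both be adjusted to start from 1.
    Pixels are numbered top to bottom, then left to right. In Python no transpose is needed; in OpenCV the image must be transposed first.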
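    To make the submission format concrete, here is a minimal sketch that parses an rle_mask string and maps a 1-based pixel index back to a (row, col) position. The names `Run`, `parseRle` and `indexToRowCol` are hypothetical helpers, not part of the competition kit, and the image height of 101 used in the usage note is an assumption about this dataset.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// One run of mask pixels: 1-based start index and run length.
struct Run {
    int start;
    int length;
};

// Parse an rle_mask string such as "1 2626 2628 100" into (start, length) runs.
std::vector<Run> parseRle(const std::string &rle)
{
    std::vector<Run> runs;
    std::istringstream in(rle);
    int start = 0, length = 0;
    while (in >> start >> length) {
        runs.push_back({start, length});
    }
    return runs;
}

// Convert a 1-based pixel index, counted top to bottom and then left to right,
// into a 0-based (row, col) pair. `height` is the number of image rows.
std::pair<int, int> indexToRowCol(int index, int height)
{
    assert(index >= 1 && height > 0);
    int zero_based = index - 1;
    return {zero_based % height, zero_based / height};
}
```

    For example, `parseRle("1 2626 2628 100")` yields two runs, and with a height of 101 the second run starts at `indexToRowCol(2628, 101)`, i.e. row 1 of column 26 (both 0-based).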

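    Data processing:

    (1) Every original image has an associated depth value. To make FCN training easier, we store the depth inside the image itself. With OpenCV, for both the labelled training images and the images to be predicted, keep the original gray value in the B channel, store depth / 256 (the integer quotient) in the G channel, and store depth % 256 (the remainder) in the R channel.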

    #include "opencv2/opencv.hpp"
    #include "opencv2/highgui/highgui.hpp"
    #include "opencv2/imgproc/imgproc.hpp"
    #include <cstdio>
    #include <cstdlib>
    #include <fstream>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>
    
    //split str on pattern, append the non-empty fields to words and return how many were added
    int split(std::string str, std::string pattern, std::vector<std::string> &words)
    {
        std::string::size_type pos;
        std::string word;
        int num = 0;
        str += pattern;
        std::string::size_type size = str.size();
        for (std::string::size_type i = 0; i < size; i++) {
            pos = str.find(pattern, i);
            if (pos == i) {
                continue;//empty field: the separator starts right here
            }
            if (pos < size) {
                word = str.substr(i, pos - i);
                words.push_back(word);
                i = pos + pattern.size() - 1;
                num++;
            }
        }
        return num;
    }
    
    int main(int argc, char** argv)
    {
        std::string input_path = "E:\\kaggle\\competition\\tgs\\test_images\\";
        std::string out_path = "E:\\kaggle\\competition\\tgs\\test_images_out\\";
        //read image name
        std::ifstream fin(input_path + "filelist.txt");
        std::string line;
        std::vector<std::string> image_names;
        //getline as the loop condition avoids pushing an empty name at end of file
        while (std::getline(fin, line)) {
            if (!line.empty()) {
                image_names.push_back(line);
            }
        }
    
        std::ifstream depth_file("E:\\kaggle\\competition\\tgs\\depths.csv");
        std::getline(depth_file, line);
        //map from image id to depth
        std::map<std::string, int> depths;
        while (std::getline(depth_file, line)) {
            std::vector<std::string> words;
            int ret = split(line, ",", words);
            if (ret == 2) {
                depths[words[0]] = atoi(words[1].c_str());
            }
        }
        std::cout << "depths size = " << depths.size() << std::endl;
    
        //read image
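        for (std::size_t i = 0; i < image_names.size(); ++i) {
            std::cout << "i = " << i << std::endl;
            cv::Mat src = cv::imread(input_path + image_names[i] + ".png", cv::IMREAD_GRAYSCALE);
            if (src.empty()) {
                std::cout << "fail: " << image_names[i] << std::endl;
                return -1;
            }
            //look the depth up once per image instead of once per pixel
            const int depth = depths[image_names[i]];
            const uchar high = static_cast<uchar>(depth / 256);
            const uchar low = static_cast<uchar>(depth % 256);
            cv::Mat dst = cv::Mat::zeros(src.size(), CV_8UC3);
            for (int r = 0; r < dst.rows; ++r) {
                const uchar *sptr = src.ptr<uchar>(r);
                cv::Vec3b *ptr = dst.ptr<cv::Vec3b>(r);
                for (int c = 0; c < dst.cols; ++c) {
                    ptr[c][0] = sptr[c];   //B: original gray value
                    ptr[c][1] = high;      //G: depth / 256
                    ptr[c][2] = low;       //R: depth % 256
                }
            }
            cv::imwrite(out_path + image_names[i] + ".png", dst);
        }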
    
    
        std::cout << "finish!" << std::endl;
        getchar();
        return 0;
    }

    Change the input and output paths and process the labelled training images in the same way.
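    The channel packing can be checked on its own, without OpenCV. The following is a minimal sketch of the quotient/remainder split and its inverse; `PackedDepth`, `packDepth` and `unpackDepth` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Two bytes holding one depth value, matching the G and R channels in the text.
struct PackedDepth {
    std::uint8_t g;  // depth / 256 (quotient)
    std::uint8_t r;  // depth % 256 (remainder)
};

// Split a depth into the two channel bytes.
PackedDepth packDepth(int depth)
{
    assert(depth >= 0 && depth < 65536);  // must fit in two 8-bit channels
    return {static_cast<std::uint8_t>(depth / 256),
            static_cast<std::uint8_t>(depth % 256)};
}

// Recover the depth from the two channel bytes.
int unpackDepth(PackedDepth p)
{
    return p.g * 256 + p.r;
}
```

    With depths in [50, 959] the G channel only takes the values 0 to 3, so most of the depth variation ends up in the R channel.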

    (2) Because this is a two-class semantic segmentation task, the labels may only be 0 and 1, so every pixel with value 255 in the mask images must be changed to 1.

    //uses the same #include lines as the program above
    int main(int argc, char** argv)
    {
        //read image name
        std::ifstream fin("filelist.txt");
        std::string line;
        std::vector<std::string> image_names;
        while (std::getline(fin, line)) {
            if (!line.empty()) {
                image_names.push_back(line);
            }
        }
    
        //modify value
        for (int num = 0; num < image_names.size(); ++num) {
            cv::Mat src;
            src = cv::imread("masks/" + image_names[num], cv::IMREAD_UNCHANGED);
            if (src.empty()) {
                std::cout << "fail: " << num << image_names[num] << std::endl;
                getchar();
                return -1;
            }
    
            cv::Mat dst;
            src.convertTo(dst, CV_8UC1);
            for (int j = 0; j < dst.rows; ++j) {
                uchar *ptr = dst.ptr<uchar>(j);
                for (int i = 0; i < dst.cols; ++i) {
                    if (ptr[i] >= 128) {
                        ptr[i] = 1;
                    }
                }
            }
            cv::imwrite("out/" + image_names[num], dst);
        }
    }

    The processed images can also be downloaded directly. Link: https://pan.baidu.com/s/1CAPIvQ6PayZ97eqeTpBcow Password: h3t9

    2. Download the open-source semantic segmentation code

    shelhamer/fcn.berkeleyvision.org: Fully Convolutional Networks for Semantic Segmentation by Jonathan Long*, Evan Shelhamer*, and Trevor Darrell. CVPR 2015 and PAMI 2016.  https://github.com/shelhamer/fcn.berkeleyvision.org

    3. Download the modified code

    https://github.com/litingpan/fcn

    4. Copy tgs-fcn32s, tgs-fcn16s and tgs-fcn8s into the fcn.berkeleyvision.org folder, copy data/tgs into the fcn.berkeleyvision.org/data folder, and copy the data processed in step 1 into the corresponding folders under tgs.

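    5. Training

    (1) Train fcn32s

    fcn.berkeleyvision.org\tgs-fcn32s>python solve.py

    [image: fcn32s training output]

    (2) Train fcn16s

    fcn.berkeleyvision.org\tgs-fcn16s>python solve.py

    [image: fcn16s training output]

    (3) Train fcn8s

    fcn.berkeleyvision.org\tgs-fcn8s>python solve.py

    [image: fcn8s training output]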

    After the 32x, 16x and 8x upsampling stages, the final scores are: overall accuracy 0.928, mean accuracy 0.887, mean IU (mean intersection over union) 0.827, and fwavacc (frequency-weighted IU) 0.866.
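    For reference, all four scores can be derived from a class confusion matrix. The sketch below uses the usual definitions of these metrics; `SegScores` and `computeScores` are hypothetical names, and this is not the scoring code from the FCN repository. Classes that never occur are skipped when averaging.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct SegScores {
    double overall_acc;  // correctly labelled pixels / all pixels
    double mean_acc;     // per-class accuracy, averaged over classes
    double mean_iu;      // per-class intersection over union, averaged
    double fwavacc;      // IU weighted by each class's pixel frequency
};

// hist[i][j] = number of pixels whose true class is i and predicted class is j.
SegScores computeScores(const std::vector<std::vector<double>> &hist)
{
    const std::size_t n = hist.size();
    assert(n > 0);
    double total = 0.0, diag_sum = 0.0;
    std::vector<double> row_sum(n, 0.0), col_sum(n, 0.0);
    for (std::size_t i = 0; i < n; ++i) {
        assert(hist[i].size() == n);
        for (std::size_t j = 0; j < n; ++j) {
            total += hist[i][j];
            row_sum[i] += hist[i][j];
            col_sum[j] += hist[i][j];
        }
        diag_sum += hist[i][i];
    }
    assert(total > 0.0);

    SegScores s{diag_sum / total, 0.0, 0.0, 0.0};
    int acc_classes = 0, iu_classes = 0;
    for (std::size_t i = 0; i < n; ++i) {
        if (row_sum[i] > 0.0) {  // class appears in the ground truth
            s.mean_acc += hist[i][i] / row_sum[i];
            ++acc_classes;
        }
        const double uni = row_sum[i] + col_sum[i] - hist[i][i];
        if (uni > 0.0) {         // class appears in truth or prediction
            const double iu = hist[i][i] / uni;
            s.mean_iu += iu;
            ++iu_classes;
            s.fwavacc += (row_sum[i] / total) * iu;
        }
    }
    if (acc_classes > 0) s.mean_acc /= acc_classes;
    if (iu_classes > 0) s.mean_iu /= iu_classes;
    return s;
}
```

    For the two-class salt problem, `hist` is a 2x2 matrix with background as class 0 and salt as class 1.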

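    6. Prediction

    We use the trained fcn8s model for prediction.

    fcn.berkeleyvision.org\tgs-fcn8s>python infers.py

    The results are written to the fcn.berkeleyvision.org\data\tgs\predict\masksout folder. Because every pixel is either 0 or 1, the images look completely black. To visualize them, use OpenCV to change 1 to 255 and save the images again.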
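    The relabelling for display is a one-line mapping per pixel. Below is a minimal sketch on a plain byte buffer; `toVisible` is a hypothetical helper name. With OpenCV the same effect can be had by scaling the image by 255, as the `convertTo(tmp, CV_8UC1, 255)` call in step 7 does.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Return a copy of a 0/1 label buffer with label 1 mapped to 255 for display.
std::vector<std::uint8_t> toVisible(const std::vector<std::uint8_t> &labels)
{
    std::vector<std::uint8_t> out(labels.size());
    for (std::size_t i = 0; i < labels.size(); ++i) {
        assert(labels[i] == 0 || labels[i] == 1);  // predictions are binary
        out[i] = labels[i] ? 255 : 0;
    }
    return out;
}
```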

    7. Save the predictions to a CSV file.

    //uses the same #include lines as the first program; dip::fillHole must be declared before main
    int main(int argc, char** argv)
    {
        //read image name
        std::ifstream fin("masksout/filelist.txt");
        std::string line;
        std::vector<std::string> image_names;
        while (std::getline(fin, line)) {
            if (!line.empty()) {
                image_names.push_back(line);
            }
        }
    
        std::ofstream fout("submission.csv");
        fout << "id,rle_mask" << std::endl;
        for (int k = 0; k < image_names.size(); ++k) {
            std::cout << "k = " << k << std::endl;
            cv::Mat src = cv::imread("masksout/" + image_names[k]);
    
            if (src.empty()) {
                return -1;
            }
    
            fout << image_names[k].substr(0, image_names[k].size()-4) << ",";
            cv::Mat gray;
            cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);
    
            cv::Mat trans;
            cv::transpose(gray, trans);
    
            //fill hole
            cv::Mat tmp;
            cv::Mat hole;
            trans.convertTo(tmp, CV_8UC1, 255);
            dip::fillHole(tmp, hole);
    
    
            bool flag = false;
            int sum = 0;
            std::vector<int> list;
            int start_id = 0;
            //hole is the transposed mask, so scanning it row by row visits the
            //original image column by column, as the submission format requires
            for (int j = 0; j < hole.rows; ++j) {
                uchar *ptr = hole.ptr<uchar>(j);
                for (int i = 0; i < hole.cols; ++i) {
                    if (ptr[i] && !flag) {
                        flag = true;
                        start_id = j*hole.cols + i+1; //1-based pixel index
                        sum = 1;
                    }
                    else if (ptr[i] && flag) {
                        sum++;
                    }
                    else if (!ptr[i] && flag){
                        flag = false;
                        list.push_back(start_id);
                        list.push_back(sum);
                    }
                }
            }//for j
            //close a run that extends to the last pixel
            if (flag) {
                list.push_back(start_id);
                list.push_back(sum);
            }
            for (int n = 0; n < list.size(); ++n) {
                if (n == 0) {
                    fout << list[0];
                }
                else {
                    fout << " " << list[n];
                }
            }
            fout << std::endl;
            if (list.size() % 2 != 0) {
                std::cout << "error " << image_names[k] << std::endl;
            }
    
        }
        std::cout << "finish!" << std::endl;
        getchar();
        return 0;
    }

    The fillHole function is:

    namespace dip {
    
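        //fill enclosed background regions (holes) in a binary 0/255 mask
        void fillHole(const cv::Mat &src, cv::Mat &dst)
        {
            cv::Size sz = src.size();
            //pad by one pixel so the background is connected around the border
            cv::Mat tmp = cv::Mat::zeros(sz.height + 2, sz.width + 2, src.type());
            src.copyTo(tmp(cv::Range(1, sz.height + 1), cv::Range(1, sz.width + 1)));
            //flood the outer background from the corner; only holes stay 0
            cv::floodFill(tmp, cv::Point(0, 0), cv::Scalar(255));
            cv::Mat cut;
            tmp(cv::Range(1, sz.height + 1), cv::Range(1, sz.width + 1)).copyTo(cut);
            //~cut is 255 exactly on the holes, so OR-ing adds them to the mask
            dst = src | (~cut);
        }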
    
    
    }
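    The run-length encoding in step 7 can also be sketched without OpenCV, which makes it easy to test on tiny masks. The version below walks a row-major buffer column by column instead of transposing it first, and it closes a run that reaches the last pixel. `rleEncode` is a hypothetical helper name, and this is an illustration of the format rather than the code used for the submission.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Encode a row-major binary mask (rows x cols, non-zero = salt) as
// "start length start length ..." using 1-based pixel indices counted
// top to bottom, then left to right.
std::string rleEncode(const std::vector<std::uint8_t> &mask, int rows, int cols)
{
    assert(rows > 0 && cols > 0);
    assert(mask.size() == static_cast<std::size_t>(rows) * cols);
    std::string out;
    int run_start = 0;   // 1-based start of the open run, 0 = no open run
    int run_length = 0;
    auto flush = [&]() {
        if (run_start == 0) return;
        if (!out.empty()) out += ' ';
        out += std::to_string(run_start) + ' ' + std::to_string(run_length);
        run_start = 0;
    };
    for (int c = 0; c < cols; ++c) {      // outer loop over columns
        for (int r = 0; r < rows; ++r) {  // inner loop runs down one column
            const bool on = mask[static_cast<std::size_t>(r) * cols + c] != 0;
            if (on) {
                if (run_start == 0) {
                    run_start = c * rows + r + 1;
                    run_length = 0;
                }
                ++run_length;
            } else {
                flush();
            }
        }
    }
    flush();  // close a run that reaches the last pixel
    return out;
}
```

    For the 2x2 mask with rows `1 0` and `1 1`, the pixels in column order are 1, 1, 0, 1, so the encoding is `1 2 4 1`.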

    8. Submit the result

    [image: screenshot of the submission result]

    This result is evidently still far from what the competition requires.


    end

  • Original post: https://www.cnblogs.com/smbx-ztbz/p/9569653.html