zoukankan      html  css  js  c++  java
  • 数据增强(每隔10度旋转一次,每张原图得到36张旋转图(0°~350°),连同原图共37张;然后从每张图片中随机截取10个patch,最终得到 原始图片数×37×10 张图片)

    # -*- coding: utf-8 -*-
    """
    Fourmi Editor

    This is a temporary script file.
    """
    import cv2
    import os
    import numpy as np
    import random
    import math


    def _copy_gray(src, dst):
        """Read *src* as a single-channel image and write it to *dst*.

        Raises:
            OSError: if *src* cannot be decoded or *dst* cannot be written.
        """
        img = cv2.imread(src, 0)
        if img is None:
            raise OSError("cannot read image: " + src)
        if not cv2.imwrite(dst, img):
            raise OSError("cannot write image: " + dst)


    def disOrdeImgs(Imgpath, Labelpath, orgTrainPath, orgTestPath,
                    labelTrainPath, labelTestPath, train_size=31912):
        """Randomly split a flat image/label dataset into train and test folders.

        ``Imgpath`` is expected to hold images named ``0.jpg`` ... ``{N-1}.jpg``
        and ``Labelpath`` the matching ``0.png`` ... ``{N-1}.png`` labels, where
        N is the number of entries in ``Imgpath``. The indices are shuffled with
        ``np.random.permutation``. The first ``train_size`` shuffled pairs are
        copied (re-encoded as single-channel images) into ``orgTrainPath`` /
        ``labelTrainPath``; the rest go into ``orgTestPath`` / ``labelTestPath``.

        Args:
            Imgpath: source directory of ``<index>.jpg`` images.
            Labelpath: source directory of ``<index>.png`` labels.
            orgTrainPath, orgTestPath: destination directories for images.
            labelTrainPath, labelTestPath: destination directories for labels.
            train_size: number of shuffled pairs sent to the training split.
                The default (31912) reproduces the original hard-coded
                ``index <= 31911`` threshold.

        Raises:
            OSError: if a source file cannot be read or a copy cannot be written.
        """
        for folder in (orgTrainPath, orgTestPath, labelTrainPath, labelTestPath):
            os.makedirs(folder, exist_ok=True)

        # File names are assumed to form a dense 0..N-1 range, so the number of
        # directory entries is the number of samples.
        count = len(os.listdir(Imgpath))
        for index, v in enumerate(np.random.permutation(count)):
            print('index:', index)
            print('v:', v)
            if index < train_size:
                img_dst, label_dst = orgTrainPath, labelTrainPath
            else:
                img_dst, label_dst = orgTestPath, labelTestPath
            _copy_gray(os.path.join(Imgpath, str(v) + '.jpg'),
                       os.path.join(img_dst, str(v) + '.jpg'))
            _copy_gray(os.path.join(Labelpath, str(v) + '.png'),
                       os.path.join(label_dst, str(v) + '.png'))


    def extract_random(full_imgs, full_masks, patch_h, patch_w, N_patches):
        """Sample random, aligned patches from each image and its mask.

        Exactly ``N_patches // len(full_imgs)`` patches are drawn from every
        image. Each patch sits at a uniformly random position that keeps the
        whole patch inside the image, and the mask patch is cut from the same
        window.

        Args:
            full_imgs: sequence of 2-D arrays (e.g. an ``(N, H, W)`` array).
            full_masks: sequence of 2-D arrays, one per image, at least as large
                as the corresponding image.
            patch_h, patch_w: patch size in pixels.
            N_patches: total number of patches; must be a multiple of
                ``len(full_imgs)``.

        Returns:
            ``(patches, patches_masks)``: two float64 arrays of shape
            ``(N_patches, patch_h, patch_w)``.

        Raises:
            ValueError: if there are no images, the mask count differs from the
                image count, ``N_patches`` is not a multiple of the image count,
                or a patch does not fit inside an image.
        """
        n_imgs = len(full_imgs)
        if n_imgs == 0:
            raise ValueError("extract_random: no images given")
        if len(full_masks) != n_imgs:
            raise ValueError("extract_random: got %d images but %d masks"
                             % (n_imgs, len(full_masks)))
        if N_patches % n_imgs != 0:
            # Report the actual image count (this used to be hard-coded to 115).
            raise ValueError("N_patches: please enter a multiple of %d" % n_imgs)
        patches = np.empty((N_patches, patch_h, patch_w))
        patches_masks = np.empty((N_patches, patch_h, patch_w))
        patch_per_img = N_patches // n_imgs
        print("patches per full image: " + str(patch_per_img))
        iter_tot = 0
        for img, mask in zip(full_imgs, full_masks):
            img_h, img_w = img.shape[0], img.shape[1]
            if patch_h > img_h or patch_w > img_w:
                raise ValueError("patch %dx%d does not fit in image %dx%d"
                                 % (patch_h, patch_w, img_h, img_w))
            for _ in range(patch_per_img):
                # Draw the top-left corner directly. The patch is then exactly
                # patch_h x patch_w even for odd sizes; the old centre +/- size/2
                # slicing lost one pixel there and crashed on assignment. For
                # even sizes this consumes the same random draws as before.
                x0 = random.randint(0, img_w - patch_w)
                y0 = random.randint(0, img_h - patch_h)
                patches[iter_tot] = img[y0:y0 + patch_h, x0:x0 + patch_w]
                patches_masks[iter_tot] = mask[y0:y0 + patch_h, x0:x0 + patch_w]
                iter_tot += 1
        return patches, patches_masks
        
        
    def imagePadding(img):
        """Return *img* resampled onto a square canvas of side ``2 * int(diagonal)``.

        Despite the name, no border is added. The whole image is stretched with
        ``cv2.resize`` (``INTER_AREA`` interpolation) so that both its height
        and its width become twice the truncated length of the original
        diagonal.
        """
        height, width = img.shape[0], img.shape[1]
        diagonal = math.sqrt(height * height + width * width)
        side = 2 * int(diagonal)
        return cv2.resize(img, (side, side), interpolation=cv2.INTER_AREA)

    def get_data(data_imgs_org,
                 data_groundTruth,
                 patch_height,
                 patch_width,
                 N_subimgs):
        """Load image/label pairs and cut them into random aligned patches.

        Images under *data_imgs_org* and their labels in *data_groundTruth* are
        loaded via ``ReadandProcessImage``. Then *N_subimgs* patches of size
        ``patch_height x patch_width`` are drawn with ``extract_random``.

        Returns:
            ``(patches, patch_masks)`` as produced by ``extract_random``.
        """
        full_imgs, full_masks = ReadandProcessImage(data_imgs_org, data_groundTruth)
        print('imgs.shape', full_imgs.shape)
        print('imgs_groundTruth', full_masks.shape)
        sampled_imgs, sampled_masks = extract_random(
            full_imgs, full_masks, patch_height, patch_width, N_subimgs)
        return sampled_imgs, sampled_masks


    def ReadandProcessImage(orgImgPath, groundTruthPath):
        """Load every image under *orgImgPath* with its label, equalised and resized.

        For each file ``<name>.<ext>`` found anywhere below *orgImgPath*, the
        label ``<name>.png`` is read from the flat folder *groundTruthPath*. Both
        are loaded as single-channel images. The image is histogram-equalised
        (``cv2.equalizeHist``), and then both go through ``imagePadding``.

        Returns:
            ``(images, labels)``: numpy arrays built from the per-file results.

        Raises:
            OSError: if an image or its label cannot be read.
            ValueError: if an image and its label differ in size.
        """
        images = []
        labels = []
        for root, dirs, files in os.walk(orgImgPath, topdown=False):
            for file in files:
                # splitext handles any extension length (file[:-4] assumed 3 chars).
                stem = os.path.splitext(file)[0]
                ImgPath = os.path.join(root, file)
                LabelPath = os.path.join(groundTruthPath, stem + '.png')
                myimg = cv2.imread(ImgPath, 0)
                mylabel = cv2.imread(LabelPath, 0)
                print('ImgPath:', ImgPath)
                print('LabelPath:', LabelPath)
                if myimg is None:
                    raise OSError("cannot read image: " + ImgPath)
                if mylabel is None:
                    raise OSError("cannot read label: " + LabelPath)
                if myimg.shape != mylabel.shape:
                    raise ValueError("image %s %s and label %s %s differ in size"
                                     % (ImgPath, myimg.shape, LabelPath, mylabel.shape))
                img = cv2.equalizeHist(myimg)
                images.append(imagePadding(img))
                labels.append(imagePadding(mylabel))
        # Return only after the whole tree has been walked. This return used to
        # sit inside the os.walk loop, so with topdown=False only the first
        # (deepest) directory was read, and a missing folder returned None.
        return np.array(images), np.array(labels)


    def roatate_img_label_to_file(imgPath, labelPath, first_index=116, step=10):
        """Augment an image/label folder with rotated copies of every pair.

        For every file ``<name>.<ext>`` under *imgPath* (with label
        ``<name>.png`` in *labelPath*), both are passed through
        ``imagePadding``. They are then rotated by 0, step, 2*step, ... degrees
        (below 360). From each rotation, a centred window of
        ``2*(orig_h//2) x 2*(orig_w//2)`` pixels is cropped. The crops are
        written back into *imgPath* / *labelPath* as ``<k>.jpg`` / ``<k>.png``,
        with k counting up from *first_index*. The defaults reproduce the
        original numbering (116, 117, ...) and 36 rotations per image.

        NOTE(review): outputs go into the same folders being read, and names may
        collide with existing files; confirm the inputs are numbered below
        *first_index*.

        Raises:
            OSError: if an image or its label cannot be read.
            ValueError: if an image and its label differ in size.
        """
        counter = first_index

        def rotateImg(img, label, orgHeight, orgWidth, imgPath, labelPath):
            """Write every rotation of one padded image/label pair."""
            nonlocal counter
            (h, w) = img.shape
            # OpenCV takes the rotation centre as (x, y) and warpAffine's dsize as
            # (width, height). imagePadding returns a square canvas, so the
            # previous (row, col) ordering happened to give the same result.
            center = (w / 2, h / 2)
            top = int(h / 2) - int(orgHeight / 2)
            bottom = int(h / 2) + int(orgHeight / 2)
            left = int(w / 2) - int(orgWidth / 2)
            right = int(w / 2) + int(orgWidth / 2)
            for angle in range(0, 360, step):
                M = cv2.getRotationMatrix2D(center, angle, 1)
                imgRotated = cv2.warpAffine(img, M, (w, h))
                # Nearest-neighbour keeps label values from being blended into
                # intermediate values along object borders.
                labelRotated = cv2.warpAffine(label, M, (w, h),
                                              flags=cv2.INTER_NEAREST)
                cv2.imwrite(os.path.join(imgPath, str(counter) + '.jpg'),
                            imgRotated[top:bottom, left:right])
                cv2.imwrite(os.path.join(labelPath, str(counter) + '.png'),
                            labelRotated[top:bottom, left:right])
                counter += 1

            print("ROTATW DONE!!!!")

        for root, dirs, files in os.walk(imgPath, topdown=False):
            for file in files:
                imgpath = os.path.join(root, file)
                labelpath = os.path.join(labelPath, os.path.splitext(file)[0] + '.png')
                img = cv2.imread(imgpath, 0)
                label = cv2.imread(labelpath, 0)
                print('imgpath:', imgpath)
                print('labelpath:', labelpath)
                if img is None:
                    raise OSError("cannot read image: " + imgpath)
                if label is None:
                    raise OSError("cannot read label: " + labelpath)
                print('imgshape:', img.shape)
                print('labelshape:', label.shape)
                if img.shape != label.shape:
                    raise ValueError("image %s and label %s differ in size"
                                     % (imgpath, labelpath))
                org_h, org_w = img.shape[0], img.shape[1]
                img = imagePadding(img)
                label = imagePadding(label)
                print('imgPadding:', img.shape)
                print('labelPadding:', label.shape)
                rotateImg(img, label, org_h, org_w, imgPath, labelPath)
             

    # Stage 1 inputs: original images and ground-truth masks.
    data_train_imgs_org="/home/chendali1/Gsj/JX/Image/train/"
    data_test_imgs_org="/home/chendali1/Gsj/JX/Image/test/"
    data_train_grountTruth="/home/chendali1/Gsj/JX/GT/train/"
    data_test_grountTruth="/home/chendali1/Gsj/JX/GT/test/"

    # Stage 1 outputs: extracted patches.
    patches_path_train='/home/chendali1/Gsj/JX/Patches/Org/train/'
    patches_path_test='/home/chendali1/Gsj/JX/Patches/Org/test/'
    patches_path_label_train='/home/chendali1/Gsj/JX/Patches/Label/train/'
    patches_path_label_test='/home/chendali1/Gsj/JX/Patches/Label/test/'


    def _build_patch_dataset():
        """Stage 1 (not run by default): rotate the training set and cut patches.

        Rotates every training image/label pair in place (written next to the
        originals) and draws ``37*115*10`` random 224x224 patches. The patches
        are written to ``patches_path_train`` / ``patches_path_label_train`` as
        ``<i>.jpg`` / ``<i>.png``. 37 = the original plus 36 rotations (0..350
        degrees in 10-degree steps); 115 is presumably the number of original
        training images (TODO confirm).

        This stage was previously kept as a disabled string literal; call it
        explicitly before running the split in ``__main__``.
        """
        for folder in (patches_path_train, patches_path_test,
                       patches_path_label_train, patches_path_label_test):
            os.makedirs(folder, exist_ok=True)
        roatate_img_label_to_file(data_train_imgs_org, data_train_grountTruth)
        train_patches, train_groundTruth = get_data(
            data_train_imgs_org, data_train_grountTruth, 224, 224, 37 * 115 * 10)
        for i in range(train_patches.shape[0]):
            cv2.imwrite(os.path.join(patches_path_train, str(i) + '.jpg'),
                        train_patches[i, :, :])
            cv2.imwrite(os.path.join(patches_path_label_train, str(i) + '.png'),
                        train_groundTruth[i, :, :])


    # Stage 2 inputs: the patches produced by stage 1.
    Imgpath ="/home/chendali1/Gsj/JX/Patches/Org/train/"
    Labelpath="/home/chendali1/Gsj/JX/Patches/Label/train/"

    # Stage 2 outputs: shuffled train/validation split.
    orgTrainPath="/home/chendali1/Gsj/DRIVE/images/training/"
    orgTestPath="/home/chendali1/Gsj/DRIVE/images/validation/"
    labelTrainPath="/home/chendali1/Gsj/DRIVE/annotations/training/"
    labelTestPath="/home/chendali1/Gsj/DRIVE/annotations/validation/"

    if __name__ == "__main__":
        # Only run the pipeline when executed as a script, not on import.
        disOrdeImgs(Imgpath,Labelpath,orgTrainPath,orgTestPath,labelTrainPath,labelTestPath)


  • 相关阅读:
    Wannafly Winter Camp 2020 Day 7D 方阵的行列式
    [CF1311F] Moving Points
    [CF1311E] Construct the Binary Tree
    [CF1311D] Three Integers
    [CF1311C] Perform the Combo
    [CF1311B] WeirdSort
    [CF1311A] Add Odd or Subtract Even
    Wannafly Winter Camp 2020 Day 7A 序列
    SP7258 SUBLEX
    Wannafly Winter Camp 2020 Day 6J K重排列
  • 原文地址:https://www.cnblogs.com/fourmi/p/9079368.html
Copyright © 2011-2022 走看看