zoukankan      html  css  js  c++  java
  • 2021.4.7 星期三

    from data.dataloader_detection import ListDataset
    from data.collate_function import collate_function
    from config.config import load_config, cfg
    from net.resnet import ResNet, ResnetBasic,ResnetBasicSlim
    from utils.nms_np_simple import nms
    from data.dataloader_test_detection import ListDataset as testListDataset
    from data.resize_uniform import resizeUniform
    
    from dataY.yolov1_dataY import DataY
    from net.yolov1 import YOLOv1
    import numpy as np
    import cv2
    from loss.yololoss import yoloLoss
    from dataY.yolov1_dataY import DataY
    import time
    
    import torch
    def plotGride(img, grideHW=(7,7), stride=64):
        # plot gride
        h,w,_ = img.shape
        for li in range(grideHW[1]):
            cv2.line(img, (int(li *stride), 0),
                     (int(li * stride), int(w)), (0, 255, 0), 1)
        for li in range(grideHW[0]):
            cv2.line(img, (0, int(li * stride)),
                     (int(h), int(li * stride)), (0, 255, 0), 1)
        return img
    
    def plotBox(img, x1, y1,x2,y2,txts):
        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
        for i in range(len(txts)):
            cv2.putText(img, txts[i], (int(x1), int(y1 + i*15)), 1, 1, (0, 0, 255))
        cv2.circle(img, (int((x1 +x2) / 2), int((y1 + y2) / 2)), color=(0, 0, 255), radius=2, thickness=-1)
        return img
    
    def post(pred, dic):
        # target contain object mask,
        b, h, w, c = pred.size()
    
        alldets= []# [[一个图片的dets],[],...]
        for i in range(b):
            apred = pred[i]
    
            """根据置信度来判断哪个网络有"""
            coobjMask1 = apred[:,:,4] > dic["scoreThresh"]
            for j in range(dic["bboxPredNum"]):
                coobjMask1 = torch.logical_or(coobjMask1, apred[:, :, 4 + i * 5] > dic["scoreThresh"])
    
            """每个cell 选择最大置信度的哪个"""
            confPred = apred[:, :, 4:5 * dic["bboxPredNum"]:5]
            _, maxIndex = torch.max(confPred, dim=-1, keepdim=True)
            h_, w_, c_ = apred.size()
            oneHot = torch.zeros(h_, w_, dic["bboxPredNum"]).to("cuda:0").scatter_(-1, maxIndex, 1).type(
                torch.bool)
    
            """最终选择"""
            coobjMask = torch.logical_and(oneHot, torch.unsqueeze(coobjMask1, -1))
            chioceIndex = torch.nonzero(coobjMask, as_tuple=False)
            chioceIndex = chioceIndex.to('cpu').numpy()
    
            dets = []
            for k in range(chioceIndex.shape[0]):
                hid, wid, boxid = chioceIndex[k][0], chioceIndex[k][1], chioceIndex[k][2]
                fullVector = apred[hid][wid].to("cpu").numpy()
                clsVec = fullVector[-dic["clsNum"]:]
                boxVector = fullVector[boxid * 5: (boxid + 1) * 5]
                deltax, deltay, w, h, score = boxVector[0], boxVector[1], boxVector[2], boxVector[3], boxVector[4]
    
                cy = (hid + deltay) * dic["stride"]
                cx = (wid + deltax) * dic["stride"]
                w = w * w * dic["netInputHw"][1]
                h = h * h * dic["netInputHw"][0]
                c = clsVec.argmax()
                dets.append([cx, cy, w, h, score, c])
            dets = np.array(dets)
            dets = dets.reshape(-1, 6)
            dets[:, :2] -= dets[:, 2:4] / 2
            dets[:, 2:4] += dets[:, :2]
    
            #boarder in the image
            dets[:,:2] = np.where(dets[:, :2] < 0, 0, dets[:,:2] )
            dets[:, 2] = np.where(dets[:, 2] > dic["netInputHw"][1], dic["netInputHw"][1], dets[:, 2])
            dets[:, 3] = np.where(dets[:, 3] > dic["netInputHw"][0], dic["netInputHw"][0], dets[:, 3])
    
            dets = nms(dets, dic["iouThresh"])
            alldets.append(dets)
        return alldets
    
    if __name__ == '__main__':
        """config"""
        load_config(cfg, "./config/config.yaml")
        print(cfg)
        device = torch.device('cuda:0')
    
        batchsize = 2
        showFlag = 1
        saveFlag = 0
        saveDir = ""
    
        mode = 1
    
        # 如果是test mode 没有label 信息, 只能是预测
        # cam Mode 调用摄像头
        # val model这个有label信息,在show的时候会展示label
        modeDict = {0:"testMode", 1:"valMode", 2:"camMode"}
        postDict = {"scoreThresh": 0.3,
                    "iouThresh": 0.3,
                    "netInputHw":(448,448),
                    "bboxPredNum": cfg.model.bboxPredNum,
                    "clsNum":cfg.model.clsNum,
                    "stride":cfg.model.stride,
                    }
    
        """dataset"""
        if modeDict[mode] == "testMode":
            dataset =  testListDataset( imgPath = cfg.dir.testImgDir,  # images root /
                             netInputSizehw = cfg.model.netInput,
                             imgChannelNumber=cfg.model.imgChannelNumber,
                             clsname= cfg.clsname
                             )
        if modeDict[mode] == "valMode":
            dataset = ListDataset(trainAnnoPath=cfg.dir.valAnnoDir, trainImgPath=cfg.dir.valImgDir,
                                    netInputSizehw=cfg.model.netInput, augFlag=False,
                                     clsname = cfg.clsname, imgChannelNumber=cfg.model.imgChannelNumber)
        if modeDict[mode] == "valMode" or modeDict[mode] == "testMode":
            dataLoader = torch.utils.data.DataLoader(
                dataset,
                collate_fn=collate_function,
                batch_size= batchsize, #cfg.train.batchSize,
                shuffle= False,
                num_workers=cfg.train.workers,
                pin_memory=True,  # 如果机器计算能力好的话,就可以设置为True,
            )
    
        datay = DataY(inputHW=cfg.model.netInput,  # 指定了inputsize 这是因为输入的是经过resize后的图片
                      gride=cfg.model.featSize,  # 将网络输入成了多少个网格
                      stride=cfg.model.stride,
                      boxNum=cfg.model.bboxPredNum,
                      clsNum=cfg.model.clsNum)
    
        """准备网络"""
        # network = ResNet(ResnetBasic, [2, 2, 2, 2], channel_out = 15)
        network = ResNet(ResnetBasicSlim,
                         [2, 2, 2, 2],
                         # [3,4,6,3],
                         channel_in=cfg.data.imgChannelNumber,
                         channel_out=(cfg.model.bboxPredNum * 5 + cfg.model.clsNum))
        # network = YOLOv1(params={"dropout": 0.5, "num_class": cfg.model.clsNum})
        network.to(device)
        weights = torch.load(cfg.dir.modelSaveDir + cfg.dir.modelName)  # 加载参数
        network.load_state_dict(weights)  # 给自己的模型加载参数
    
        with torch.no_grad():
            if modeDict[mode] == "camMode":
                cap = cv2.VideoCapture(0)
    
                while (1):
                    assert cfg.model.imgChannelNumber == 3, "输入通道目前支支持3通道"
                    ret, img = cap.read()
                    frame = np.copy(img)
                    if not ret:
                        print("can not cap a picture")
                        time.sleep(1)
                        continue
                    img, infos = resizeUniform(img, cfg.model.netInput)
    
                    imgs = np.array([img])
                    imgs = torch.from_numpy(imgs.astype(np.float32)),
                    imgs = imgs.to(device).float()
                    mean = torch.tensor(cfg.data.normalize[0]).cuda().reshape(3, 1, 1)
                    std = torch.tensor(cfg.data.normalize[1]).cuda().reshape(3, 1, 1)
                    imgs = (imgs - mean) / std
    
                    pred = network(imgs)
                    pred = pred.permute(0, 2, 3, 1)  # rotate the BCHW to BHWC
    
                    """post"""
                    bcDets = post(pred, postDict)
                    dets = bcDets[0]
    
                    cv2.imshow("capture", img)
                    """plot pred"""
                    imgp = plotGride(frame, grideHW=(cfg.model.featSize), stride=cfg.model.stride)
                    for i in range(dets.shape[0]):
                        x1, y1, x2, y2, score, cls = 
                            dets[i][0], dets[i][1], dets[i][2], dets[i][3], dets[i][4], dets[i][5]
                        imgp = plotBox(imgp, x1, y1, x2, y2,
                                       ["s: " + str(round(score, 3)), "c: " + cfg.clsname[cls]])
                    cv2.imshow("pred", imgp)
                    cv2.waitKey()
                    if cv2.waitKey(1) & 0xFF == 32:
                        break
    
    
            if modeDict[mode] == "valMode" or modeDict[mode] == "testMode":
                for id, infos in enumerate(dataLoader): #每个batch
                    """forward and pred"""
                    imgs = infos['images']
                    imgs = imgs.to(device).float()
                    mean = torch.tensor(cfg.data.normalize[0]).cuda().reshape(3, 1, 1)
                    std = torch.tensor(cfg.data.normalize[1]).cuda().reshape(3, 1, 1)
                    imgs = (imgs - mean) / std
    
                    pred = network(imgs)
                    pred = pred.permute(0, 2, 3, 1)  # rotate the BCHW to BHWC
    
                    """post"""
                    bcDets = post(pred, postDict)
    
                    for bcid in range(batchsize):
                        dets = bcDets[bcid]
    
                        if showFlag:
                            image = infos['images'][bcid]
                            image = image.to('cpu').numpy()
                            image = image.transpose(1, 2, 0).astype(np.uint8)
                            image = cv2.UMat(image).get()
                            imgt = np.copy(image)
                            imgp = np.copy(image)
    
                            """plot pred"""
                            imgp = plotGride(imgp, grideHW=(cfg.model.featSize), stride=cfg.model.stride)
                            for i in range(dets.shape[0]):
                                x1, y1, x2, y2, score, cls = 
                                    dets[i][0], dets[i][1], dets[i][2], dets[i][3], dets[i][4], dets[i][5]
                                imgp = plotBox(imgp, x1, y1, x2, y2,
                                               [str(i)+" s: " + str(round(score, 3)), "c: " + cfg.clsname[cls]])
                                print("pred[%d x1:%.2f y1:%.2f w:%.2f h:%.2f score: %.2f "%(i, x1,y1,x2-x1,y2-y1,score)+ cfg.clsname[cls] +"]")
                            cv2.imshow("pred", imgp)
                            cv2.waitKey()
    
                        if showFlag and not modeDict[mode] == "testModeFlag":# test mode 没有target
                            """read target"""
                            bboxesGt = infos['bboxesGt'][bcid]
                            classesGt = infos['classes'][bcid]
                            annoName = infos["annoName"][bcid]
    
                            """plot target"""
                            imgt = plotGride(imgt, grideHW=(cfg.model.featSize), stride=cfg.model.stride)
                            for i in range(bboxesGt.shape[0]):
                                x1, y1, w, h =  bboxesGt[i]
                                cls = classesGt[i]
                                plotBox(imgt, x1,y1,x1+w,y1+h,[cfg.clsname[cls]])
                                print("target[x1:%.2f y1:%.2f w:%.2f h:%.2f " % (x1, y1, w, h)+ cfg.clsname[cls]+"]")
                            print(annoName)
                            cv2.imshow("target", imgt)
                            cv2.waitKey()
                            print("-"*50)
    
                        # """post"""
                        # # target contain object mask
                        # coobjMask_ = pred[:, :, :, 4] > scoreThresh
                        # for i in range(1, cfg.model.bboxPredNum):
                        #     coobjMask_ = torch.logical_or(coobjMask_, pred[:, :, :, 4 + i * 5] > scoreThresh)
                        #
                        # #选择每个anchor 点 预测置信度最大的, bbox num = 1 代码也可以
                        # confPred = pred[:, :, :, 4:5 * cfg.model.bboxPredNum:5]
                        # _, maxIndex = torch.max(confPred, dim=-1, keepdim=True)
                        # b_, h_, w_, c_ = pred.size()
                        # oneHot = torch.zeros(b_, h_, w_, cfg.model.bboxPredNum).to("cuda:0").scatter_(-1, maxIndex, 1).type(
                        #     torch.bool)
                        #
                        # coobjMask = torch.logical_and(oneHot, torch.unsqueeze(coobjMask_, -1))
                        # chioceIndex = torch.nonzero(coobjMask, as_tuple=False)  #
                        #
                        # """ plot pred """
                        # chioceIndex = chioceIndex.to('cpu').numpy()
                        # dets = []
                        # imgp = np.copy(image)
                        # for i in range(chioceIndex.shape[0]):
                        #     imId, hid, wid, boxid = chioceIndex[i][0], chioceIndex[i][1], chioceIndex[i][2], chioceIndex[i][3]
                        #
                        #     fullVector = pred[imId][hid][wid].to("cpu").numpy()
                        #     clsVec = fullVector[-cfg.model.clsNum:]
                        #     boxVector = fullVector[boxid * 5: (boxid + 1) * 5]
                        #     deltax, deltay, w, h, score = boxVector[0], boxVector[1], boxVector[2], boxVector[3], boxVector[4]
                        #
                        #     cy = (hid + deltay) * cfg.model.stride
                        #     cx = (wid + deltax) * cfg.model.stride
                        #     w = w*w* cfg.model.netInput[1]
                        #     h = h*h * cfg.model.netInput[0]
                        #     c = clsVec.argmax()
                        #     dets.append([cx, cy, w, h, score, c])
                        # dets = np.array(dets)
                        # dets = dets.reshape(-1, 6)
                        # dets[:, :2] -= dets[:, 2:4] / 2
                        # dets[:, 2:4] += dets[:, :2]
                        # # dets = nms(dets, iouThresh)
    
                        if saveFlag and  modeDict[mode] == "valModeFlag":
                            np.savetxt(saveDir + infos["annoName"][bcid],dets)
                        if saveFlag and modeDict[mode] == "testModeFlag":
                            np.savetxt(saveDir + infos["imgName"][bcid].split["."][0]+".txt",dets)
  • 相关阅读:
    Building a Basic .NET Remoting Application
    Building a Basic .NET Remoting Application 之一 Building a Remotable Type
    Type Library Exporter (Tlbexp.exe)
    .NET Remoting Architecture
    Runtime Hosts
    Building a Basic .NET Remoting Application 之二 Building a Host Application
    Building a Basic .NET Remoting Application 之三 Building a Client Application
    C语言字符串大全(转自百度百科)
    作业 双向链表
    递推练习 简单n!
  • 原文地址:https://www.cnblogs.com/yunshangyue71/p/14628040.html
Copyright © 2011-2022 走看看