zoukankan      html  css  js  c++  java
  • 数据集train和test分类脚本,以nyu数据集为例

    # coding: UTF-8
    # Change the argument passed to classify() to change the train/test split ratio
    # Run this script from $project/tools (all paths below are relative to that directory)
    import os
    
    
    def find_last(string, str):
        """Return the index of the last occurrence of ``str`` in ``string``.

        Args:
            string: The text to search.
            str: The substring to look for. The name shadows the builtin but
                is kept so existing keyword callers keep working.

        Returns:
            The start index of the last match, or -1 when there is no match.
        """
        # str.rfind yields the same result as repeatedly calling find() from
        # the previous hit, including for overlapping and empty needles.
        return string.rfind(str)
    
    # Root folder holding one sub-directory per recording.
    filePath = '../data/images/'
    pathDir = sorted(os.listdir(filePath))
    sceneSum = len(pathDir)

    # A scene name is the directory name up to (not including) its last '_';
    # collect the distinct names in sorted order.
    sceneList = sorted({entry[0:find_last(entry, '_')] for entry in pathDir})
    sceneNum = len(sceneList)
    print("NYU data sets have %d scenes, they are: " % sceneNum)
    print(sceneList)
    
    
    def classify(train_test=0.6):
        """Split the images of every scene into a train list and a test list.

        For each scene in the module-level ``sceneList``, the images found in
        the matching directories of ``pathDir`` (under ``filePath``) are
        gathered in directory order. The first ``int(train_test * count)`` of
        them are written to ``../data/train.txt`` and the remainder to
        ``../data/test.txt``, one ``<directory>/<file>`` entry per line.
        Per-scene image counts and the train/test totals are printed.

        Args:
            train_test: Fraction of each scene's images that goes to the
                train list (0.6 means 60% train, 40% test).
        """
        # Gather every scene's image entries once, so the directory tree is
        # listed a single time and the counts always match what gets written.
        sceneImages = []
        for sceneIndex in sceneList:
            images = []
            for dirName in pathDir:
                # Compare the exact scene prefix. A substring test would also
                # put e.g. 'home_office_0001' under the scene 'office'.
                if dirName[0:find_last(dirName, '_')] != sceneIndex:
                    continue
                for fileName in os.listdir(filePath + dirName):
                    images.append(dirName + '/' + fileName)
            sceneImages.append(images)

        eachScene = [len(images) for images in sceneImages]
        print('Each scenes has images:')
        print('%s Total in %d images' % (eachScene, sum(eachScene)))

        trainNum = 0
        testNum = 0
        # 'with' guarantees both files are closed even if a write fails.
        with open('../data/train.txt', 'w') as txtTrain, \
                open('../data/test.txt', 'w') as txtTest:
            for images in sceneImages:
                # Clamp at 0 so a negative ratio sends everything to test
                # instead of slicing from the end of the list.
                classifyTrain = max(0, int(train_test * len(images)))
                trainImages = images[:classifyTrain]
                testImages = images[classifyTrain:]
                txtTrain.writelines(entry + '\n' for entry in trainImages)
                txtTest.writelines(entry + '\n' for entry in testImages)
                trainNum += len(trainImages)
                testNum += len(testImages)

        print('The sum images of train set is %d' % trainNum)
        print('The sum images of test set is %d' % testNum)
    
    # Run the split with 60% of each scene's images going to the train list.
    classify(train_test=0.6)

    本文为原创,转载需注明!

  • 相关阅读:
    VMWare虚拟机下为Ubuntu 12.04.1配置静态IP(NAT方式)
    VMWare虚拟机下为Windows Server 2012配置静态IP(NAT方式)
    Windows 7防火墙阻止了远程桌面连接的解决方法
    Win10系统如何在防火墙里开放端口
    ECharts 定制 label 样式
    目标值柱状图
    echarts中datazoom相关配置
    环形图
    带时间轴的指标监控柱状图
    2020mysql面试题
  • 原文地址:https://www.cnblogs.com/roboai/p/7587435.html
Copyright © 2011-2022 走看看