zoukankan      html  css  js  c++  java
  • 数据集train和test划分脚本,以NYU数据集为例

    # coding: UTF-8
    # Change the argument of classify() (the fraction of each scene used for training) to change the train/test split ratio
    # Use this script in $project/tools
    import os
    
    
    def find_last(string, str):
        """Return the index of the last occurrence of ``str`` in ``string``.

        Returns -1 when ``str`` does not occur.  Overlapping occurrences count,
        e.g. ``find_last('aaaa', 'aa') == 2``.  This is exactly ``str.rfind``,
        which searches once from the right instead of calling ``find`` over and
        over from the left.  The parameter names (``str`` shadows the built-in)
        are kept unchanged so existing keyword callers keep working.
        """
        return string.rfind(str)
    
    filePath = '../data/images/'
    # Only sub-directories hold images.  A stray file (e.g. .DS_Store) would
    # otherwise be reported as a scene and make os.listdir() fail in classify().
    pathDir = sorted(name for name in os.listdir(filePath)
                     if os.path.isdir(os.path.join(filePath, name)))
    sceneSum = len(pathDir)

    # A directory's scene is its name up to the last '_'
    # (e.g. 'kitchen_0001' -> 'kitchen').
    # NOTE(review): a name with no '_' makes find_last return -1, so the slice
    # drops the name's last character; presumably every directory here is
    # named '<scene>_<id>' -- confirm.
    sceneList = sorted(set(allDir[0:find_last(allDir, '_')] for allDir in pathDir))
    sceneNum = len(sceneList)
    print("NYU data sets have %d scenes, they are: " % sceneNum)
    print(sceneList)
    
    
    def classify(train_test=0.6):
        """Write per-scene train/test image lists.

        Uses the module-level ``filePath``, ``pathDir`` and ``sceneList``.  For
        every scene with N images, the first ``int(train_test * N)`` images
        (directories in ``pathDir`` order, files sorted by name) are written to
        ``../data/train.txt`` and the remaining ones to ``../data/test.txt``.
        Each output line is ``<directory>/<image file name>``.  Both output
        paths are relative to the current working directory.  Per-scene and
        total image counts are printed.

        Args:
            train_test: fraction of each scene's images assigned to the train
                set (default 0.6).
        """
        # Group directories by their *exact* scene name, derived with the same
        # rule used to build sceneList.  A substring test would also file e.g.
        # 'home_office_0001' under the scene 'office' and double-count it.
        sceneDirs = {scene: [] for scene in sceneList}
        for dirName in pathDir:
            completeDir = os.path.join(filePath, dirName)
            if not os.path.isdir(completeDir):
                continue  # stray files (e.g. .DS_Store) contain no images
            scene = dirName[0:find_last(dirName, '_')]
            if scene in sceneDirs:
                # List each directory once.  Sort because os.listdir order is
                # arbitrary, which would make the split non-reproducible.
                sceneDirs[scene].append((dirName, sorted(os.listdir(completeDir))))

        eachScene = [sum(len(files) for _, files in sceneDirs[scene])
                     for scene in sceneList]
        print('Each scenes has images:')
        print('%s Total in %d images' % (eachScene, sum(eachScene)))

        trainNum = 0
        testNum = 0
        # 'with' guarantees both lists are closed even if a write fails.
        with open('../data/train.txt', 'w') as txtTrain, \
                open('../data/test.txt', 'w') as txtTest:
            for scene, sceneTotal in zip(sceneList, eachScene):
                classifyTrain = int(train_test * sceneTotal)
                written = 0  # images of this scene emitted so far
                for dirName, files in sceneDirs[scene]:
                    for fileName in files:
                        writeLine = dirName + '/' + fileName + '\n'
                        if written < classifyTrain:
                            txtTrain.write(writeLine)
                            trainNum += 1
                        else:
                            txtTest.write(writeLine)
                            testNum += 1
                        written += 1
        print('The sum images of train set is %d' % trainNum)
        print('The sum images of test set is %d' % testNum)
    
    # Entry point: about 60% of each scene's images go to train, the rest to test.
    classify(train_test=0.6)

    本文为原创,转载需注明!

  • 相关阅读:
    SAP ABAP Netweaver服务器的标准登录方式讲解
    php导出百万数据到csv
    消息中间件Kafka
    kafka安装
    Linux系统下安装jdk及环境配置(两种方法)
    PHP导出3w条数据成表格
    excel 导出导入
    利用Redis锁解决高并发问题
    BeyondCompare4破解方法
    Linux(Ubuntu)通过nfs挂载远程硬盘
  • 原文地址:https://www.cnblogs.com/roboai/p/7587435.html
Copyright © 2011-2022 走看看