zoukankan      html  css  js  c++  java
  • Jtester+unitils+testng:DAO单元测试文件模板自动生成


         定位

         本文适合于不愿意手工编写而想自动化生成DAO单元测试的筒鞋。成果是不能照搬的,但其中的"创建模板、填充内容、自动生成"思想是可以复用的。读完本文,可以了解 Python 读取配置文件、替换字符串相关的知识点。

         在使用 jtester+unitils+testng 做数据库接口的单元测试框架中, 常常需要编写一些 wiki 及 DAOTest java 文件,  比如:

    /**
     * DbFit-driven tests for the DAO of the xxx_default table.
     * Each test prepares the table from the wiki file named in {@code when}; where a
     * {@code then} wiki is given, that file holds the data used to verify the result.
     */
    public class XXXDefaultDAOTest extends BaseRegionDbDAOTestCase {
    
        // DAO under test, injected by bean name.
        @SpringBeanByName
        private XXXDefaultDAO XXXDefaultDAO;
        
        /**
         * Inserts one fully populated record into the cleaned table; the stored row is
         * checked by the queryOneRecord wiki.
         */
        @Test
        @DbFit(when="XXXDefaultDAOTest.initBlank.when.wiki", then="XXXDefaultDAOTest.queryOneRecord.then.wiki")
        public void testInsertXXXDefaultDO() {
            XXXDefaultDO XXXDefaultDO = new XXXDefaultDO();
            XXXDefaultDO.setId(1L);
            XXXDefaultDO.setCidrBlock("192.168.10.10");
            XXXDefaultDO.setIpProtocol("tcp");
            XXXDefaultDO.setPortRange("3000:4000");
            XXXDefaultDO.setPolicy(Policy.POLICY_ACCEPT);
            XXXDefaultDO.setNic(Nic.INTRANET);
            XXXDefaultDO.setPriority(65533L);
            XXXDefaultDO.setType(1L);
            XXXDefaultDO.setIsDeleted(0L);
            XXXDefaultDO.setDescription("test1");
            XXXDefaultDO.setGmtCreate(new Date());
            XXXDefaultDO.setGmtModify(new Date());
            XXXDefaultDAO.insertXXXDefaultDO(XXXDefaultDO);
        }
    
        /**
         * Counts with an example object that has no field set and expects exactly one
         * record, the single row inserted by the initRecords wiki.
         */
        @Test
        @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
        public void testCountXXXDefaultDOByExample() {
            XXXDefaultDO XXXDefaultDO = new XXXDefaultDO();
            Assert.assertTrue(XXXDefaultDAO.countXXXDefaultDOByExample(XXXDefaultDO) == 1);
        }
    
        /**
         * Loads the record with primary key 6, changes protocol, nic and description, and
         * saves it; the updated row is checked by the testUpdate wiki.
         */
        @Test
        @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki", then="XXXDefaultDAOTest.testUpdate.then.wiki")
        public void testUpdateXXXDefaultDO() {
            XXXDefaultDO found = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(6L);
            found.setIpProtocol("udp");
            found.setNic(Nic.INTERNET);
            found.setDescription("desc");
            XXXDefaultDAO.updateXXXDefaultDO(found);
        }
    
        /**
         * Queries by an example carrying a CIDR block and a policy; expects a single hit
         * whose two filtered fields equal the values that were asked for.
         */
        @Test
        @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
        public void testFindListByExample() {
            String cidrBlock = "10.152.126.83";
            Policy policy = Policy.POLICY_ACCEPT;
            XXXDefaultDO XXXDefault = new XXXDefaultDO();
            XXXDefault.setCidrBlock(cidrBlock);
            XXXDefault.setPolicy(policy);
            List<XXXDefaultDO> list = XXXDefaultDAO.findListByExample(XXXDefault);
            Assert.assertEquals(list.size(), 1);
            for (XXXDefaultDO XXXDefaultDO: list) {
                Assert.assertEquals(XXXDefaultDO.getCidrBlock(), cidrBlock);
                Assert.assertEquals(XXXDefaultDO.getPolicy(), policy);
            }
        }
    
        /**
         * Loads the record with primary key 6 and compares its fields with the expected
         * values (presumably the row defined in the initRecords wiki; verify against it).
         */
        @Test
        @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki")
        public void testFindXXXDefaultDOByPrimaryKey() {
            XXXDefaultDO found = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(6L);
            Assert.assertEquals(found.getCidrBlock(), "10.152.126.83");
            Assert.assertEquals(found.getIpProtocol(), "all");
            Assert.assertEquals(found.getPortRange(), "");
            Assert.assertEquals(found.getPolicy(), Policy.POLICY_ACCEPT);
            Assert.assertEquals(found.getNic(), Nic.BOTH);
            Assert.assertEquals(found.getPriority().longValue(),1L);
            Assert.assertEquals(found.getType().intValue(), 1);
            Assert.assertEquals(found.getIsDeleted().intValue(), 0);
            Assert.assertEquals(found.getDescription(), "bie dong");
        }
    
        /**
         * Deletes the record with primary key 6 twice: the first call must report one
         * affected row, the second none.
         * NOTE(review): the empty {@code then} presumably disables the post-check; confirm.
         */
        @Test
        @DbFit(when="XXXDefaultDAOTest.initRecords.when.wiki", then="")
        public void testDeleteXXXDefaultDOByPrimaryKey() {
            Integer count = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(6L);
            Assert.assertEquals(count.intValue(), 1);
            
            // The record is gone now, so a repeated delete must not affect any row.
            Integer nodelete = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(6L);
            Assert.assertEquals(nodelete.intValue(), 0);
        }
    
    }

           

           其中, 数据准备文件在 *.when.wiki 中, 数据验证文件在 *.then.wiki 中, 数据库中只需要保证正确的表结构即可。 每次单元测试都是自动化可重复的。

    XXXDefaultDAOTest.initBlank.when.wiki
    |connect|
    |clean table|xxx_default|
          
    XXXDefaultDAOTest.initRecords.when.wiki
    |connect|
    |clean table|xxx_default|
    |clean table|xxx|
    |insert|xxx_default|
    | id | gmt_create | gmt_modify | cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
    | 6 | 2014-04-08 20:18:04 | 2014-04-08 20:18:04 | 10.152.126.83 | all | | accept | 3 | 1 | 1 | 0 | bie dong |

    XXXDefaultDAOTest.queryOneRecord.then.wiki
    |connect|
    |query|select cidr_block, ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default |
    |cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
    |192.168.10.10 | tcp | 3000:4000 | accept | 2 | 65533 | 1 | 0 | test1 |

    XXXDefaultDAOTest.testUpdate.then.wiki
    |connect|
    |query|select cidr_block, ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default|
    | cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
    | 10.152.126.83 | udp | | accept | 1 | 1 | 1 | 0 | desc |

             

      显然, 如果每个 DAO 测试类都写这些 WIKI  及 DAO 类(set/get 字段很耗体力), 那会是比较大的工作量。 这时候, 最好能够自动生成这些文件或文件模板, 减少手工的劳动量。

             因此, 编写了一个 python 程序, 在指定配置下, 可以自动生成相关的测试文件模板。

          readcfg.py :  读取DAO测试类信息的配置文件     

    from ConfigParser import ConfigParser
    
    # Module-level parser shared by the helpers below; it is filled once, when this module is imported.
    config = ConfigParser()
    # The path is relative, so it is resolved against the current working directory.
    # NOTE(review): ConfigParser.read() skips files it cannot open, so a missing daotest.conf
    # presumably yields zero sections instead of an error; confirm.
    config.read("daotest.conf")
    
    def getAllDAOTestInfo():
        '''
            Collect the settings of every section of the loaded configuration.
            Returns a dict that maps each section name to the dict built by getDAOTestInfo.
        '''
        return dict((sectionName, getDAOTestInfo(sectionName)) for sectionName in config.sections())
    
    def getDAOTestInfo(daoTestName) :
        '''
            Read the four options of one section of the loaded configuration.
            daoTestName: name of the section to read.
            Returns a dict keyed by 'DaoTestName', 'TableName', 'FieldArray' and 'NumTypeFields'.
        '''
        optionNames = ('DaoTestName', 'TableName', 'FieldArray', 'NumTypeFields')
        return dict((optionName, config.get(daoTestName, optionName)) for optionName in optionNames)

        create_daotest_wiki.py : 生成 dao 测试的测试文件模板:    

    import readcfg
    import time
    import re
    
    def gene_daotest(daoTestInfo):
        '''
            Generate the wiki fixtures and the java test class for one DAO test
            and print how long the generation took.

            daoTestInfo: dict with the keys 'DaoTestName', 'TableName', 'FieldArray'
                         and 'NumTypeFields'; the last two hold comma-separated field names.
        '''
        daoTestName = daoTestInfo['DaoTestName']
        tableName = daoTestInfo['TableName']
        # Split on commas and swallow the whitespace around them, so 'a, b' and 'a,b'
        # give the same names; empty entries (e.g. an empty option) are dropped.
        fieldArray = [name for name in re.split(r'\s*,\s*', daoTestInfo['FieldArray'].strip()) if name]
        numTypeFields = set(name for name in re.split(r'\s*,\s*', daoTestInfo['NumTypeFields'].strip()) if name)

        # Single-argument print(...) behaves the same as a statement and as a function.
        print(' ***  %s  start...\n' % daoTestName)

        # Wall-clock timing; time.clock() is not available on current Python versions.
        startTime = time.time()
        gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields)
        gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields)
        endTime = time.time()

        print(' ***  %s  finished.\n' % daoTestName)
        print('time cost:  %sms.\n' % str((endTime - startTime)*1000))
    
    def gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields):
        '''
            generate the wikies used for DAO test java file

            Four files are written into the current directory, each named after daoTestName:
              .initBlank.when.wiki      connect + clean the table
              .initRecords.when.wiki    connect + clean the table + insert one row of default values
              .queryOneRecord.then.wiki query of the non-time, non-id fields and their default values
              .testUpdate.then.wiki     same content as queryOneRecord

            daoTestName   : prefix of the generated file names
            tableName     : table the wiki commands operate on
            fieldArray    : all field names of the table
            numTypeFields : names of the numeric fields (passed on to the default-value lookup)
        '''
        conn = '|connect|'
        clean_table = '|clean table|' + tableName + '|'
        insert_table = '|insert|' + tableName + '|'
        all_fields = '|' + getfieldsWithSep(fieldArray, 0, '|') + '|'
        query_stmt = '|query|' + 'select ' + getfieldsWithSep(fieldArray, 0, ', ', filterTimeAndIdFieldFunc) + ' from ' + tableName + '|'
        query_fields = '|' + getfieldsWithSep(fieldArray, 0, '|', filterTimeAndIdFieldFunc) + '|'
        all_fields_default_values = '|' + getfieldValuesWithSep(fieldArray, numTypeFields, 0, '|') + '|'
        query_fields_default_values = '|' + getfieldValuesWithSep(fieldArray, numTypeFields, 0, '|', filterTimeAndIdFieldFunc) + '|'

        queryLines = [conn, query_stmt, query_fields, query_fields_default_values]

        _writeWikiFile(daoTestName+".initBlank.when.wiki", [conn, clean_table])
        _writeWikiFile(daoTestName+".initRecords.when.wiki",
                       [conn, clean_table, insert_table, all_fields, all_fields_default_values])
        _writeWikiFile(daoTestName+".queryOneRecord.then.wiki", queryLines)
        _writeWikiFile(daoTestName+".testUpdate.then.wiki", queryLines)

    def _writeWikiFile(fileName, lines):
        '''
            write the given lines to fileName, joined by newlines (no trailing newline);
            the with-block guarantees the file is flushed and closed
        '''
        with open(fileName, 'w') as wikiFile:
            wikiFile.write('\n'.join(lines))
        
    def gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields):
        '''
            generate the DAO test java file from TemplateDefaultDAOTest.java

            The template is read from the current directory; 'XXX' is replaced by the
            part of daoTestName before 'DAOTest', 'YYY' by the same text with a lowered
            first letter, '$setFields' and '$AssertGetValues' by generated statements
            for all fields except the gmt_ ones. The result is written to daoTestName + '.java'.

            tableName is not used here.

            Raises ValueError if daoTestName has no non-empty prefix before 'DAOTest'.
        '''
        daoPrefixIndex = daoTestName.find('DAOTest')
        if daoPrefixIndex <= 0:
            # find() gives -1 when the marker is missing; slicing with -1 would silently
            # chop the last character, and an empty prefix cannot name a class.
            raise ValueError("DaoTestName '%s' must look like '<Prefix>DAOTest'" % daoTestName)
        XXXReplacer = daoTestName[0: daoPrefixIndex]
        YYYReplacer = firstLowerCase(XXXReplacer)
        # list(...) so the fields can be walked twice even if the filter result is lazy
        filteredFieldArray = list(getFilteredFields(fieldArray, filterTimeFieldFunc))

        with open('TemplateDefaultDAOTest.java') as f_daotest_tmpl:
            content = f_daotest_tmpl.read()

        contentReplaced = content.replace('XXX', XXXReplacer).replace('YYY', YYYReplacer) \
                                 .replace('$setFields', geneSetFields(filteredFieldArray, numTypeFields, YYYReplacer)) \
                                 .replace('$AssertGetValues', geneAssertGetValues(filteredFieldArray, numTypeFields, YYYReplacer))

        # the output file is only created once the content is complete
        with open(daoTestName+'.java', 'w') as f_daotest_java:
            f_daotest_java.write(contentReplaced)
    
    
    def geneAssertGetValues(fieldArray, numTypeFields, YYYReplacer):
        '''
            build one 'Assert.assertEquals(obj.getField(), value);' statement per field

            fieldArray    : field names in underline form
            numTypeFields : fields whose expected value is written without double quotes
            YYYReplacer   : name of the java variable the getters are called on

            Every statement is followed by a newline and two tabs, so the next
            statement lines up inside the template.
        '''
        content = ''
        for field in fieldArray:
            quoteStr = '' if field in numTypeFields else '"'
            content += 'Assert.assertEquals(%s.get%s(), %s%s%s);\n%s' % (
                YYYReplacer, transformField(field), quoteStr,
                getDefaultValueForField(field, numTypeFields), quoteStr, indentTimes(2))
        return content
        
    def geneSetFields(fieldArray, numTypeFields, YYYReplacer):
        '''
            build one 'obj.setField(value);' statement per field

            fieldArray    : field names in underline form
            numTypeFields : fields whose value is written without double quotes
            YYYReplacer   : name of the java variable the setters are called on

            Every statement is followed by a newline and two tabs, so the next
            statement lines up inside the template.
        '''
        content = ''
        for field in fieldArray:
            quoteStr = '' if field in numTypeFields else '"'
            content += '%s.set%s(%s%s%s);\n%s' % (
                YYYReplacer, transformField(field), quoteStr,
                getDefaultValueForField(field, numTypeFields), quoteStr, indentTimes(2))
        return content
        
    def transformField(field):
        '''
           convert field with UnderLine form to Camel Form
           eg.  cidr_block ==> CidrBlock

           Empty segments (leading, trailing or doubled underscores) contribute
           nothing instead of raising IndexError.
        '''
        # part[:1] is '' for an empty segment, whereas part[0] would fail
        return ''.join(part[:1].upper() + part[1:] for part in field.split('_'))
    
    def indentTimes(num):
        '''
            return num tab characters; zero or a negative num gives an empty string
        '''
        # the explicit escape keeps the tab visible and safe from editors that
        # convert raw tabs to spaces
        return '\t' * max(num, 0)
        
    def firstLowerCase(input):
        '''
            the first letter lowered. eg. NcDAOTest ==> ncDAOTest
            an empty string is returned unchanged
        '''
        # input[:1] is '' for an empty string, whereas input[0] would raise IndexError
        return input[:1].lower() + input[1:]
        
    def firstSuperCase(input):
        '''
            the first letter uppered. eg. ncDAOTest ==> NcDAOTest
            an empty string is returned unchanged
        '''
        # input[:1] is '' for an empty string, whereas input[0] would raise IndexError
        return input[:1].upper() + input[1:]
        
    def nopFunc(field):
        '''
            filter that keeps every field: the answer is True whatever field is
        '''
        keepField = True
        return keepField
    
    def currTime():
        '''
            the current local time formatted as 'YYYY-mm-dd HH:MM:SS'
        '''
        # without an explicit time tuple strftime formats the current local time
        timeFormat = '%Y-%m-%d %H:%M:%S'
        return time.strftime(timeFormat)
        
    def getDefaultValueForField(field, numTypeFields):
        '''
            get the default value of field and return a value
            if you want to set more proper value , do it here.

            field         : field name in underline form
            numTypeFields : collection of the field names that hold numbers

            Rules, first match wins:
              'id' or a later word 'id' (nc_id)   ==> '1'
              name contains 'gmt_'                ==> current time
              field is a numeric field            ==> '0'
              name contains 'cidr'                ==> '172.16.0.0/22'
              a word 'ip' or name contains 'addr' ==> '172.16.0.1'
              anything else                       ==> 'test-' + field

            Words are the underline-separated parts of the name. Matching 'ip' and
            'id' on whole words keeps names like 'description' or 'x_idle' from being
            mistaken for addresses or ids, and numeric fields are decided before the
            text rules so that they never receive a non-numeric value.
        '''
        words = field.split('_')
        if field == 'id' or 'id' in words[1:]:
            return '1'
        elif field.find('gmt_') != -1:
            return currTime()
        elif field in numTypeFields:
            return '0'
        elif field.find('cidr') != -1:
            return '172.16.0.0/22'
        elif 'ip' in words or field.find('addr') != -1:
            return '172.16.0.1'
        else :
            return 'test-'+field
            
    def getfieldsWithSep(fieldArray, index=0, sep='|', filterFunc=nopFunc):
        '''
            join the field names accepted by filterFunc with sep

            fieldArray : all field names
            index      : number of accepted fields to skip at the front; must be in [0, len(fieldArray)]
            sep        : separator put between the names
            filterFunc : called with a field name, returns True to keep it

            Raises ValueError (an Exception subclass) when index is out of range.
        '''
        if index < 0 or index > len(fieldArray):
            # %-formatting, because adding an int to a str raises TypeError
            raise ValueError('index %s invalid: must be in [0,%s]' % (index, len(fieldArray)))
        # list(...) so that slicing works even if the filter result is lazy
        fieldFilteredArray = list(getFilteredFields(fieldArray, filterFunc))
        return sep.join(fieldFilteredArray[index:])
        
    def getfieldValuesWithSep(fieldArray, numTypeFields, index=0, sep='|', filterFunc=nopFunc):
        '''
            join the default values of the fields accepted by filterFunc with sep

            fieldArray    : all field names
            numTypeFields : names of the numeric fields (passed on to the default-value lookup)
            index         : number of accepted fields to skip at the front; must be in [0, len(fieldArray)]
            sep           : separator put between the values
            filterFunc    : called with a field name, returns True to keep it

            Raises ValueError (an Exception subclass) when index is out of range.
        '''
        if index < 0 or index > len(fieldArray):
            # %-formatting, because adding an int to a str raises TypeError
            raise ValueError('index %s invalid: must be in [0,%s]' % (index, len(fieldArray)))
        # list(...) so that slicing works even if the filter result is lazy
        fieldFilteredArray = list(getFilteredFields(fieldArray, filterFunc))
        # index is applied the same way as in getfieldsWithSep, so names and values stay aligned
        fieldDefaultValues = [getDefaultValueForField(field, numTypeFields) for field in fieldFilteredArray[index:]]
        return sep.join(fieldDefaultValues)
        
    def filterTimeAndIdFieldFunc(field):
        '''
            keep a field unless its name contains 'gmt_' or it is exactly 'id'
        '''
        isTimeField = 'gmt_' in field
        isIdField = field == 'id'
        return not (isTimeField or isIdField)
    
    def filterTimeFieldFunc(field):
        '''
            keep a field unless its name contains 'gmt_'
        '''
        return 'gmt_' not in field
        
    def getFilteredFields(fieldArray, filterFunc):
        '''
            return a new list with the fields for which filterFunc returns a true value,
            in their original order

            A real list is built because the callers slice the result and walk it more
            than once, which a lazy filter object does not allow.
        '''
        return [field for field in fieldArray if filterFunc(field)]
        
    if __name__ == '__main__':
        # Generate the wiki fixtures and the java test class for every section of daotest.conf.
        # Only the section contents are needed, so walk the values; values() exists on
        # the dicts of both Python 2 and Python 3, unlike iteritems().
        for daoTestInfo in readcfg.getAllDAOTestInfo().values():
            gene_daotest(daoTestInfo)

         daotest.conf:  配置文件  

    [VmDAOTest]
    DaoTestName=VmDAOTest
    TableName=vm
    FieldArray=id,gmt_create,gmt_modify,vm_name,cores,mem,disk,status,nc_id,is_deleted
    NumTypeFields=id,cores,mem,disk,status,nc_id,is_deleted
    
    [NcDAOTest]
    DaoTestName=NcDAOTest
    TableName=nc
    FieldArray=id,gmt_create,gmt_modify,hostname,ip,avail_cpu, avail_mem, avail_disk
    NumTypeFields=id,avail_cpu, avail_mem, avail_disk

         DAO java 文件模板: 

    package xxx.dao.regiondb.impl;
    
    import java.util.Date;
    import java.util.List;
    
    import org.jtester.unitils.dbfit.DbFit;
    import org.testng.Assert;
    import org.testng.annotations.Test;
    import org.unitils.spring.annotation.SpringBeanByName;
    
    import xxx.BaseRegionDbDAOTestCase;
    import xxx.constant.group.Nic;
    import xxx.constant.group.Policy;
    import xxx.dao.regiondb.XXXDAO;
    import xxx.model.db.XXXDO;
    
    /**
     * Skeleton of a DbFit-driven DAO test. The generator substitutes the class-name
     * placeholders and expands the two dollar-prefixed markers into setter calls and
     * assertions that use the generated default values.
     */
    public class XXXDAOTest extends BaseRegionDbDAOTestCase {
    
        @SpringBeanByName
        private XXXDAO YYYDAO;
        
        /** Inserts one record into the cleaned table; the row is checked by the queryOneRecord wiki. */
        @Test
        @DbFit(when="XXXDAOTest.initBlank.when.wiki", then="XXXDAOTest.queryOneRecord.then.wiki")
        public void testInsertXXXDO() {
            XXXDO YYY = new XXXDO();
            $setFields
            YYY.setGmtCreate(new Date());
            YYY.setGmtModify(new Date());
            YYYDAO.insertXXXDO(YYY);
        }
    
        /** Counts with an example object that has no field set; the fixture holds exactly one row. */
        @Test
        @DbFit(when="XXXDAOTest.initRecords.when.wiki")
        public void testCountXXXDOByExample() {
            XXXDO YYY = new XXXDO();
            Assert.assertTrue(YYYDAO.countXXXDOByExample(YYY).intValue() == 1);
        }
    
        /** Loads the fixture row, sets its fields and saves it; the row is checked by the testUpdate wiki. */
        @Test
        @DbFit(when="XXXDAOTest.initRecords.when.wiki", then="XXXDAOTest.testUpdate.then.wiki")
        public void testUpdateXXXDO() {
            // The generated initRecords fixture inserts its single row with id 1.
            XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(1L);
            $setFields
            YYYDAO.updateXXXDO(YYY);
        }
    
        /** Queries by an example built from the default values and checks every returned row. */
        @Test
        @DbFit(when="XXXDAOTest.initRecords.when.wiki")
        public void testFindListByExample() {
            XXXDO YYY = new XXXDO();
            $setFields
            List<XXXDO> list = YYYDAO.findListByExample(YYY);
            Assert.assertEquals(list.size(), 1);
            for (XXXDO YYYDO: list) {
                // Check the row that came back, not the example that was sent.
                assertDefaultValues(YYYDO);
            }
        }
    
        /** Loads the fixture row by its key and checks all of its fields. */
        @Test
        @DbFit(when="XXXDAOTest.initRecords.when.wiki")
        public void testFindXXXDOByPrimaryKey() {
            XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(1L);
            assertDefaultValues(YYY);
        }
    
        /** Deletes the fixture row twice: the first call must affect one row, the second none. */
        @Test
        @DbFit(when="XXXDAOTest.initRecords.when.wiki")
        public void testDeleteXXXDOByPrimaryKey() {
            Integer count = YYYDAO.deleteXXXDOByPrimaryKey(1L);
            Assert.assertEquals(count.intValue(), 1);
            
            Integer nodelete = YYYDAO.deleteXXXDOByPrimaryKey(1L);
            Assert.assertEquals(nodelete.intValue(), 0);
        }
    
        /** Asserts that every generated field of the given record holds its default value. */
        private void assertDefaultValues(XXXDO YYY) {
            $AssertGetValues
        }
    
    }

         运行: $ python create_daotest_wiki.py

         生成以下文件: 

         

          其中: 
         

    VmDAOTest.initBlank.when.wiki  
    
    |connect|
    |clean table|vm|
     
    
    VmDAOTest.initRecords.when.wiki    
    
    |connect|
    |clean table|vm|
    |insert|vm|
    |id|gmt_create|gmt_modify|vm_name|cores|mem|disk|status|nc_id|is_deleted|
    |1|2014-05-22 12:51:38|2014-05-22 12:51:38|test-vm_name|0|0|0|0|1|0|
     
    
     VmDAOTest.queryOneRecord.then.wiki / VmDAOTest.testUpdate.then.wiki
    
    |connect|
    |query|select vm_name, cores, mem, disk, status, nc_id, is_deleted from vm|
    |vm_name|cores|mem|disk|status|nc_id|is_deleted|
    |test-vm_name|0|0|0|0|1|0| 

          生成的DAOTEST Java 文件:  

    package xxx.dao.regiondb.impl;
    
    import java.util.Date;
    import java.util.List;
    
    import org.jtester.unitils.dbfit.DbFit;
    import org.testng.Assert;
    import org.testng.annotations.Test;
    import org.unitils.spring.annotation.SpringBeanByName;
    
    import xxx.BaseRegionDbDAOTestCase;
    import xxx.constant.group.Nic;
    import xxx.constant.group.Policy;
    import xxx.dao.regiondb.VmDAO;
    import xxx.model.db.VmDO;
    
    public class VmDAOTest extends BaseRegionDbDAOTestCase {
    
        @SpringBeanByName
        private VmDAO vmDAO;
        
        @Test
        @DbFit(when="VmDAOTest.initBlank.when.wiki", then="VmDAOTest.queryOneRecord.then.wiki")
        public void testInsertVmDO() {
            VmDO vm = new VmDO();
            vm.setId(1);
            vm.setVmName("test-vm_name");
            vm.setCores(0);
            vm.setMem(0);
            vm.setDisk(0);
            vm.setStatus(0);
            vm.setNcId(1);
            vm.setIsDeleted(0);
            
            vm.setGmtCreate(new Date());
            vm.setGmtModify(new Date());
            vmDAO.insertVmDO(vm);
        }
    
        @Test
        @DbFit(when="VmDAOTest.initRecords.when.wiki")
        public void testCountVmDOByExample() {
            VmDO vm = new VmDO();
            Assert.assertTrue(vmDAO.countVmDOByExample(vm).intValue() == 1);
        }
    
        @Test
        @DbFit(when="VmDAOTest.initRecords.when.wiki", then="VmDAOTest.testUpdate.then.wiki")
        public void testUpdateVmDO() {
            VmDO vm = vmDAO.findVmDOByPrimaryKey();
            vm.setId(1);
            vm.setVmName("test-vm_name");
            vm.setCores(0);
            vm.setMem(0);
            vm.setDisk(0);
            vm.setStatus(0);
            vm.setNcId(1);
            vm.setIsDeleted(0);
            
            vmDAO.updateVmDO(vm);
        }
    
        @Test
        @DbFit(when="VmDAOTest.initRecords.when.wiki")
        public void testFindListByExample() {
            VmDO vm = new VmDO();
            vm.setId(1);
            vm.setVmName("test-vm_name");
            vm.setCores(0);
            vm.setMem(0);
            vm.setDisk(0);
            vm.setStatus(0);
            vm.setNcId(1);
            vm.setIsDeleted(0);
            
            List<VmDO> list = vmDAO.findListByExample(vm);
            Assert.assertEquals(list.size(), 1);
            for (VmDO vmDO: list) {
                Assert.assertEquals(vm.getId(), 1);
            Assert.assertEquals(vm.getVmName(), "test-vm_name");
            Assert.assertEquals(vm.getCores(), 0);
            Assert.assertEquals(vm.getMem(), 0);
            Assert.assertEquals(vm.getDisk(), 0);
            Assert.assertEquals(vm.getStatus(), 0);
            Assert.assertEquals(vm.getNcId(), 1);
            Assert.assertEquals(vm.getIsDeleted(), 0);
            
            }
        }
    
        @Test
        @DbFit(when="VmDAOTest.initRecords.when.wiki")
        public void testFindVmDOByPrimaryKey() {
            VmDO vm = vmDAO.findVmDOByPrimaryKey(1L);
            Assert.assertEquals(vm.getId(), 1);
            Assert.assertEquals(vm.getVmName(), "test-vm_name");
            Assert.assertEquals(vm.getCores(), 0);
            Assert.assertEquals(vm.getMem(), 0);
            Assert.assertEquals(vm.getDisk(), 0);
            Assert.assertEquals(vm.getStatus(), 0);
            Assert.assertEquals(vm.getNcId(), 1);
            Assert.assertEquals(vm.getIsDeleted(), 0);
            
        }
    
        @Test
        @DbFit(when="VmDAOTest.initRecords.when.wiki")
        public void testDeleteVmDOByPrimaryKey() {
            Integer count = vmDAO.deleteVmDOByPrimaryKey(1L);
            Assert.assertEquals(count.intValue(), 1);
            
            Integer nodelete = vmDAO.deleteVmDOByPrimaryKey(1L);
            Assert.assertEquals(nodelete.intValue(), 0);
        }
    
    }

         结语:

         只要是手工劳动, 尽可能自动化。而要做到自动化, 第一是规范标准化, 第二是要发现一些规律性的模式。 

     

  • 相关阅读:
    [JOYOI1326] 剑人合一
    linux hive +mysql(mysql用于hive元数据存储)
    hadoop 伪分布式单机部署练习hive
    pyhton 操作hive数据仓库
    python操作hadoop HDFS api使用
    hadoop伪集群部署
    python 文件指针切割文件
    jdk8 permgen OOM再见迎来metaspace
    java JVM内存区域模型
    java垃圾回收
  • 原文地址:https://www.cnblogs.com/lovesqcc/p/4037695.html
Copyright © 2011-2022 走看看