定位
本文适合不愿意手工编写、而想自动生成 DAO 单元测试的同学。具体代码未必能直接照搬, 但其中"创建模板、填充内容、自动生成"的思路是可以复用的。读完本文, 可以了解 Python 读取配置文件、替换字符串等相关知识点。
在使用 jtester+unitils+testng 做数据库接口的单元测试框架中, 常常需要编写一些 wiki 及 DAOTest java 文件, 比如:
/**
 * Hand-written DbFit-driven test for {@code XXXDefaultDAO}.
 *
 * <p>Every test prepares the table from a {@code *.when.wiki} fixture; tests that
 * name a {@code then} fixture additionally have the table contents verified
 * against it after the test body runs.
 */
public class XXXDefaultDAOTest extends BaseRegionDbDAOTestCase {

    /** Primary key of the single row seeded by XXXDefaultDAOTest.initRecords.when.wiki. */
    private static final long SEEDED_ID = 6L;

    // NOTE(review): @SpringBeanByName appears to resolve the bean from the field
    // name, so this field keeps its original (unconventional) name.
    @SpringBeanByName
    private XXXDefaultDAO XXXDefaultDAO;

    /** Inserts one row into an empty table; queryOneRecord.then.wiki checks it. */
    @Test
    @DbFit(when = "XXXDefaultDAOTest.initBlank.when.wiki",
           then = "XXXDefaultDAOTest.queryOneRecord.then.wiki")
    public void testInsertXXXDefaultDO() {
        XXXDefaultDO record = new XXXDefaultDO();
        record.setId(1L);
        record.setCidrBlock("192.168.10.10");
        record.setIpProtocol("tcp");
        record.setPortRange("3000:4000");
        record.setPolicy(Policy.POLICY_ACCEPT);
        record.setNic(Nic.INTRANET);
        record.setPriority(65533L);
        record.setType(1L);
        record.setIsDeleted(0L);
        record.setDescription("test1");
        record.setGmtCreate(new Date());
        record.setGmtModify(new Date());
        XXXDefaultDAO.insertXXXDefaultDO(record);
    }

    /** An empty example object matches every row, and exactly one row is seeded. */
    @Test
    @DbFit(when = "XXXDefaultDAOTest.initRecords.when.wiki")
    public void testCountXXXDefaultDOByExample() {
        Assert.assertTrue(XXXDefaultDAO.countXXXDefaultDOByExample(new XXXDefaultDO()) == 1);
    }

    /** Loads the seeded row, changes three columns, and lets testUpdate.then.wiki verify the result. */
    @Test
    @DbFit(when = "XXXDefaultDAOTest.initRecords.when.wiki",
           then = "XXXDefaultDAOTest.testUpdate.then.wiki")
    public void testUpdateXXXDefaultDO() {
        XXXDefaultDO row = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(SEEDED_ID);
        row.setIpProtocol("udp");
        row.setNic(Nic.INTERNET);
        row.setDescription("desc");
        XXXDefaultDAO.updateXXXDefaultDO(row);
    }

    /** Queries by CIDR block and policy and expects only the seeded row back. */
    @Test
    @DbFit(when = "XXXDefaultDAOTest.initRecords.when.wiki")
    public void testFindListByExample() {
        final String expectedCidr = "10.152.126.83";
        final Policy expectedPolicy = Policy.POLICY_ACCEPT;

        XXXDefaultDO example = new XXXDefaultDO();
        example.setCidrBlock(expectedCidr);
        example.setPolicy(expectedPolicy);

        List<XXXDefaultDO> matches = XXXDefaultDAO.findListByExample(example);
        Assert.assertEquals(matches.size(), 1);
        for (XXXDefaultDO match : matches) {
            Assert.assertEquals(match.getCidrBlock(), expectedCidr);
            Assert.assertEquals(match.getPolicy(), expectedPolicy);
        }
    }

    /** Every column of the seeded row must round-trip through the DAO. */
    @Test
    @DbFit(when = "XXXDefaultDAOTest.initRecords.when.wiki")
    public void testFindXXXDefaultDOByPrimaryKey() {
        XXXDefaultDO row = XXXDefaultDAO.findXXXDefaultDOByPrimaryKey(SEEDED_ID);
        Assert.assertEquals(row.getCidrBlock(), "10.152.126.83");
        Assert.assertEquals(row.getIpProtocol(), "all");
        Assert.assertEquals(row.getPortRange(), "");
        Assert.assertEquals(row.getPolicy(), Policy.POLICY_ACCEPT);
        Assert.assertEquals(row.getNic(), Nic.BOTH);
        Assert.assertEquals(row.getPriority().longValue(), 1L);
        Assert.assertEquals(row.getType().intValue(), 1);
        Assert.assertEquals(row.getIsDeleted().intValue(), 0);
        Assert.assertEquals(row.getDescription(), "bie dong");
    }

    /** Deleting the seeded row affects one row; deleting it again affects none. */
    @Test
    @DbFit(when = "XXXDefaultDAOTest.initRecords.when.wiki", then = "")
    public void testDeleteXXXDefaultDOByPrimaryKey() {
        Integer firstDelete = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(SEEDED_ID);
        Assert.assertEquals(firstDelete.intValue(), 1);

        Integer secondDelete = XXXDefaultDAO.deleteXXXDefaultDOByPrimaryKey(SEEDED_ID);
        Assert.assertEquals(secondDelete.intValue(), 0);
    }
}
其中, 数据准备文件在 *.when.wiki 中, 数据验证文件在 *.then.wiki 中, 数据库中只需要保证正确的表结构即可。 每次单元测试都是自动化可重复的。
XXXDefaultDAOTest.initBlank.when.wiki
|connect|
|clean table|xxx_default|
XXXDefaultDAOTest.initRecords.when.wiki
|connect|
|clean table|xxx_default|
|clean table|xxx|
|insert|xxx_default|
| id | gmt_create | gmt_modify | cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
| 6 | 2014-04-08 20:18:04 | 2014-04-08 20:18:04 | 10.152.126.83 | all | | accept | 3 | 1 | 1 | 0 | bie dong |
XXXDefaultDAOTest.queryOneRecord.then.wiki
|connect|
|query|select cidr_block, ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default |
|cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
|192.168.10.10 | tcp | 3000:4000 | accept | 2 | 65533 | 1 | 0 | test1 |
XXXDefaultDAOTest.testUpdate.then.wiki
|connect|
|query|select cidr_block, ip_protocol, port_range, policy, nic, priority, type, is_deleted, description from xxx_default|
| cidr_block | ip_protocol | port_range | policy | nic | priority | type | is_deleted | description |
| 10.152.126.83 | udp | | accept | 1 | 1 | 1 | 0 | desc |
显然, 如果每个 DAO 测试类都要手工编写这些 wiki 文件及 DAO 测试类(逐个字段写 set/get 很耗体力), 工作量会比较大。这时候, 最好能够自动生成这些文件或文件模板, 减少手工劳动。
因此, 编写了一个 Python 程序, 根据配置文件自动生成相关的测试模板文件。
readcfg.py : 读取DAO测试类信息的配置文件
"""Read DAO test descriptions from daotest.conf.

Each section of daotest.conf describes one DAO test class and must define the
options DaoTestName, TableName, FieldArray and NumTypeFields.
"""
try:
    from ConfigParser import ConfigParser  # Python 2
except ImportError:
    from configparser import ConfigParser  # Python 3 renamed the module

# Options copied out of every section.
DAO_TEST_KEYS = ('DaoTestName', 'TableName', 'FieldArray', 'NumTypeFields')

# Parsed once at import time, relative to the current working directory.
# ConfigParser.read() silently skips files it cannot open, so a missing
# daotest.conf simply yields no sections.
config = ConfigParser()
config.read("daotest.conf")


def getAllDAOTestInfo():
    '''
    Return a dict mapping every section name in daotest.conf to the
    settings dict produced by getDAOTestInfo() for that section.
    '''
    return dict((sec, getDAOTestInfo(sec)) for sec in config.sections())


def getDAOTestInfo(daoTestName):
    '''
    Return the settings of section `daoTestName` as a dict with the keys
    DaoTestName, TableName, FieldArray and NumTypeFields (raw string values).

    Raises the ConfigParser NoSectionError / NoOptionError if the section or
    one of the options is missing.
    '''
    return dict((key, config.get(daoTestName, key)) for key in DAO_TEST_KEYS)
create_daotest_wiki.py : 生成 DAO 测试所需的 wiki 文件及 Java 测试类模板:
"""
Generate DbFit wiki fixtures and a DAO test class skeleton for every DAO
test described in daotest.conf (read through the readcfg module).

For each configured DaoTestName the script writes, into the current
working directory:
    <DaoTestName>.initBlank.when.wiki
    <DaoTestName>.initRecords.when.wiki
    <DaoTestName>.queryOneRecord.then.wiki
    <DaoTestName>.testUpdate.then.wiki
    <DaoTestName>.java   (rendered from TemplateDefaultDAOTest.java)

Written to run unchanged on Python 2.7 and Python 3.
"""
import re
import time

# time.clock() was removed in Python 3.8; prefer perf_counter when present.
_timer = getattr(time, 'perf_counter', time.time)

# Comma separator with any surrounding whitespace. It must be a raw string:
# without the backslashes, 's*,s*' also eats literal 's' characters next to a
# comma (e.g. "disk,status" would split into ['disk', 'tatus']).
_FIELD_SEP = re.compile(r'\s*,\s*')


def splitFields(text):
    '''
    Split a comma-separated field list into names, ignoring whitespace
    around commas and empty entries (e.g. from a trailing comma).
    eg. "id, gmt_create ,name," ==> ['id', 'gmt_create', 'name']
    '''
    return [name for name in _FIELD_SEP.split(text.strip()) if name]


def gene_daotest(daoTestInfo):
    '''
    Generate the wiki fixtures and the Java test skeleton for one DAO test.

    daoTestInfo: dict with the keys 'DaoTestName', 'TableName', 'FieldArray'
    and 'NumTypeFields'; the last two are comma-separated field lists.
    '''
    daoTestName = daoTestInfo['DaoTestName']
    tableName = daoTestInfo['TableName']
    fieldArray = splitFields(daoTestInfo['FieldArray'])
    numTypeFields = set(splitFields(daoTestInfo['NumTypeFields']))
    print(' ***  %s  start... ' % daoTestName)
    startTime = _timer()
    gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields)
    gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields)
    endTime = _timer()
    print(' ***  %s  finished. ' % daoTestName)
    print('time cost:  %sms. ' % ((endTime - startTime) * 1000))


def _write_text(path, text):
    '''Write text to path (replacing any existing file) and close it.'''
    with open(path, 'w') as f:
        f.write(text)


def gene_daotest_wiki_really(daoTestName, tableName, fieldArray, numTypeFields):
    '''
    Generate the DbFit wiki files used by the generated DAO test class.
    Every wiki table row is written on its own line.
    '''
    conn = '|connect|'
    clean_table = '|clean table|' + tableName + '|'
    insert_table = '|insert|' + tableName + '|'
    all_fields = '|' + getfieldsWithSep(fieldArray, 0, '|') + '|'
    query_stmt = ('|query|' + 'select '
                  + getfieldsWithSep(fieldArray, 0, ', ', filterTimeAndIdFieldFunc)
                  + ' from ' + tableName + '|')
    query_fields = '|' + getfieldsWithSep(fieldArray, 0, '|', filterTimeAndIdFieldFunc) + '|'
    all_fields_default_values = '|' + getfieldValuesWithSep(fieldArray, numTypeFields, 0, '|') + '|'
    query_fields_default_values = '|' + getfieldValuesWithSep(
        fieldArray, numTypeFields, 0, '|', filterTimeAndIdFieldFunc) + '|'

    # queryOneRecord and testUpdate start out identical; edit testUpdate by hand
    # to reflect whatever the update test changes.
    query_rows = [conn, query_stmt, query_fields, query_fields_default_values]
    wikis = [
        ('.initBlank.when.wiki', [conn, clean_table]),
        ('.initRecords.when.wiki',
         [conn, clean_table, insert_table, all_fields, all_fields_default_values]),
        ('.queryOneRecord.then.wiki', query_rows),
        ('.testUpdate.then.wiki', query_rows),
    ]
    for suffix, rows in wikis:
        _write_text(daoTestName + suffix, '\n'.join(rows) + '\n')


def gene_daotest_java(daoTestName, tableName, fieldArray, numTypeFields):
    '''
    Render TemplateDefaultDAOTest.java into <daoTestName>.java.

    Template placeholders:
      XXX               -> prefix before 'DAOTest', e.g. 'Vm' for 'VmDAOTest'
      YYY               -> that prefix with a lower-case first letter, e.g. 'vm'
      $setFields        -> one setter call per non-gmt_ field
      $AssertGetValues  -> one Assert.assertEquals per non-gmt_ field

    tableName is not used here; it is accepted for symmetry with
    gene_daotest_wiki_really.

    Raises ValueError if daoTestName does not contain 'DAOTest'.
    '''
    daoPrefixIndex = daoTestName.find('DAOTest')
    if daoPrefixIndex < 0:
        # find() returns -1 here, and daoTestName[0:-1] would silently chop
        # the last character off the class name.
        raise ValueError("DaoTestName '%s' must contain 'DAOTest'" % daoTestName)
    XXXReplacer = daoTestName[:daoPrefixIndex]
    YYYReplacer = firstLowerCase(XXXReplacer)
    filteredFieldArray = getFilteredFields(fieldArray, filterTimeFieldFunc)

    with open('TemplateDefaultDAOTest.java') as f_daotest_tmpl:
        content = f_daotest_tmpl.read()
    contentReplaced = (content
                       .replace('XXX', XXXReplacer)
                       .replace('YYY', YYYReplacer)
                       .replace('$setFields',
                                geneSetFields(filteredFieldArray, numTypeFields, YYYReplacer))
                       .replace('$AssertGetValues',
                                geneAssertGetValues(filteredFieldArray, numTypeFields, YYYReplacer)))
    _write_text(daoTestName + '.java', contentReplaced)


def _javaLiteral(field, numTypeFields):
    '''Default value of field as Java source: bare for numeric fields, double-quoted otherwise.'''
    quoteStr = '' if field in numTypeFields else '"'
    return quoteStr + getDefaultValueForField(field, numTypeFields) + quoteStr


def geneAssertGetValues(fieldArray, numTypeFields, YYYReplacer):
    '''
    Build one assertion per field comparing the getter with the default value.
    eg. Assert.assertEquals(vm.getVmName(), "test-vm_name");
    '''
    return ''.join('Assert.assertEquals(%s.get%s(), %s); %s'
                   % (YYYReplacer, transformField(field),
                      _javaLiteral(field, numTypeFields), indentTimes(2))
                   for field in fieldArray)


def geneSetFields(fieldArray, numTypeFields, YYYReplacer):
    '''
    Build one setter call per field using the field's default value.
    eg. vm.setVmName("test-vm_name");
    '''
    return ''.join('%s.set%s(%s); %s'
                   % (YYYReplacer, transformField(field),
                      _javaLiteral(field, numTypeFields), indentTimes(2))
                   for field in fieldArray)


def transformField(field):
    '''
    Convert a field in underscore form to UpperCamelCase.
    eg. cidr_block ==> CidrBlock
    '''
    return ''.join(firstSuperCase(part) for part in field.split('_'))


def indentTimes(num):
    '''Return `num` spaces ('' when num <= 0).'''
    return ' ' * max(num, 0)


def firstLowerCase(input):
    '''
    Lower-case the first letter; the empty string is returned unchanged.
    eg. NcDAOTest ==> ncDAOTest
    '''
    return input[:1].lower() + input[1:]


def firstSuperCase(input):
    '''
    Upper-case the first letter; the empty string is returned unchanged.
    eg. ncDAOTest ==> NcDAOTest
    '''
    return input[:1].upper() + input[1:]


def nopFunc(field):
    '''Filter that keeps every field.'''
    return True


def currTime():
    '''Current local time formatted as 'YYYY-mm-dd HH:MM:SS'.'''
    return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))


def getDefaultValueForField(field, numTypeFields):
    '''
    Return a placeholder value (as a string) for a field, chosen by name:
      id, *_id*             -> '1'
      contains 'gmt_'       -> current local time
      ends with 'ip' token,
      or contains 'addr'    -> '172.16.0.1'
      contains 'cidr'       -> '172.16.0.0/22'
      in numTypeFields      -> '0'
      anything else         -> 'test-<field>'
    If you want more realistic values, adjust them here.
    '''
    if field == 'id' or '_id' in field:
        return '1'
    if 'gmt_' in field:
        return currTime()
    # Match 'ip' as the last underscore-separated token: a plain substring
    # test would also hit unrelated columns such as 'description' or
    # 'ip_protocol'.
    if field.split('_')[-1] == 'ip' or 'addr' in field:
        return '172.16.0.1'
    if 'cidr' in field:
        return '172.16.0.0/22'
    if field in numTypeFields:
        return '0'
    return 'test-' + field


def _checkIndex(index, fieldArray):
    '''Raise ValueError unless 0 <= index <= len(fieldArray).'''
    if index < 0 or index > len(fieldArray):
        raise ValueError('index %s invalid: must be in [0,%d]' % (index, len(fieldArray)))


def getfieldsWithSep(fieldArray, index=0, sep='|', filterFunc=nopFunc):
    '''
    Join the fields accepted by filterFunc, starting at position `index` of
    the filtered list, with `sep`. Raises ValueError for an out-of-range index.
    '''
    _checkIndex(index, fieldArray)
    return sep.join(getFilteredFields(fieldArray, filterFunc)[index:])


def getfieldValuesWithSep(fieldArray, numTypeFields, index=0, sep='|', filterFunc=nopFunc):
    '''
    Join the default values of the fields accepted by filterFunc with `sep`.
    Raises ValueError for an out-of-range index.
    '''
    # NOTE(review): unlike getfieldsWithSep, `index` is validated but not used
    # for slicing here; this preserves the original behaviour.
    _checkIndex(index, fieldArray)
    return sep.join(getDefaultValueForField(field, numTypeFields)
                    for field in getFilteredFields(fieldArray, filterFunc))


def filterTimeAndIdFieldFunc(field):
    '''Keep fields that are neither gmt_* timestamps nor the id column.'''
    return 'gmt_' not in field and field != 'id'


def filterTimeFieldFunc(field):
    '''Keep fields that are not gmt_* timestamps.'''
    return 'gmt_' not in field


def getFilteredFields(fieldArray, filterFunc):
    '''Return (as a list, on both Python 2 and 3) the fields accepted by filterFunc.'''
    return [field for field in fieldArray if filterFunc(field)]


if __name__ == '__main__':
    # Imported only when run as a script: readcfg reads daotest.conf at import
    # time, and the helpers above should be importable without it.
    import readcfg
    allDAOTest = readcfg.getAllDAOTestInfo()
    for daoTestName in sorted(allDAOTest):
        gene_daotest(allDAOTest[daoTestName])
daotest.conf: 配置文件
[VmDAOTest]
DaoTestName=VmDAOTest
TableName=vm
FieldArray=id,gmt_create,gmt_modify,vm_name,cores,mem,disk,status,nc_id,is_deleted
NumTypeFields=id,cores,mem,disk,status,nc_id,is_deleted

[NcDAOTest]
DaoTestName=NcDAOTest
TableName=nc
FieldArray=id,gmt_create,gmt_modify,hostname,ip,avail_cpu,avail_mem,avail_disk
NumTypeFields=id,avail_cpu,avail_mem,avail_disk
DAO 测试 Java 文件模板 (TemplateDefaultDAOTest.java, 其中 XXX、YYY、$setFields、$AssertGetValues 为占位符):
package xxx.dao.regiondb.impl; import java.util.Date; import java.util.List; import org.jtester.unitils.dbfit.DbFit; import org.testng.Assert; import org.testng.annotations.Test; import org.unitils.spring.annotation.SpringBeanByName; import xxx.BaseRegionDbDAOTestCase; import xxx.constant.group.Nic; import xxx.constant.group.Policy; import xxx.dao.regiondb.XXXDAO; import xxx.model.db.XXXDO; public class XXXDAOTest extends BaseRegionDbDAOTestCase { @SpringBeanByName private XXXDAO YYYDAO; @Test @DbFit(when="XXXDAOTest.initBlank.when.wiki", then="XXXDAOTest.queryOneRecord.then.wiki") public void testInsertXXXDO() { XXXDO YYY = new XXXDO(); $setFields YYY.setGmtCreate(new Date()); YYY.setGmtModify(new Date()); YYYDAO.insertXXXDO(YYY); } @Test @DbFit(when="XXXDAOTest.initRecords.when.wiki") public void testCountXXXDOByExample() { XXXDO YYY = new XXXDO(); Assert.assertTrue(YYYDAO.countXXXDOByExample(YYY).intValue() == 1); } @Test @DbFit(when="XXXDAOTest.initRecords.when.wiki", then="XXXDAOTest.testUpdate.then.wiki") public void testUpdateXXXDO() { XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(); $setFields YYYDAO.updateXXXDO(YYY); } @Test @DbFit(when="XXXDAOTest.initRecords.when.wiki") public void testFindListByExample() { XXXDO YYY = new XXXDO(); $setFields List<XXXDO> list = YYYDAO.findListByExample(YYY); Assert.assertEquals(list.size(), 1); for (XXXDO YYYDO: list) { $AssertGetValues } } @Test @DbFit(when="XXXDAOTest.initRecords.when.wiki") public void testFindXXXDOByPrimaryKey() { XXXDO YYY = YYYDAO.findXXXDOByPrimaryKey(1L); $AssertGetValues } @Test @DbFit(when="XXXDAOTest.initRecords.when.wiki") public void testDeleteXXXDOByPrimaryKey() { Integer count = YYYDAO.deleteXXXDOByPrimaryKey(1L); Assert.assertEquals(count.intValue(), 1); Integer nodelete = YYYDAO.deleteXXXDOByPrimaryKey(1L); Assert.assertEquals(nodelete.intValue(), 0); } }
运行: $ python create_daotest_wiki.py
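生成以下文件(以 VmDAOTest 为例, NcDAOTest 同理): VmDAOTest.initBlank.when.wiki、VmDAOTest.initRecords.when.wiki、VmDAOTest.queryOneRecord.then.wiki、VmDAOTest.testUpdate.then.wiki 以及 VmDAOTest.java。
其中各 wiki 文件的内容如下: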
VmDAOTest.initBlank.when.wiki
|connect|
|clean table|vm|
VmDAOTest.initRecords.when.wiki
|connect|
|clean table|vm|
|insert|vm|
|id|gmt_create|gmt_modify|vm_name|cores|mem|disk|status|nc_id|is_deleted|
|1|2014-05-22 12:51:38|2014-05-22 12:51:38|test-vm_name|0|0|0|0|1|0|
VmDAOTest.queryOneRecord.then.wiki / VmDAOTest.testUpdate.then.wiki (两者内容相同)
|connect|
|query|select vm_name, cores, mem, disk, status, nc_id, is_deleted from vm|
|vm_name|cores|mem|disk|status|nc_id|is_deleted|
|test-vm_name|0|0|0|0|1|0|
生成的DAOTEST Java 文件:
package xxx.dao.regiondb.impl; import java.util.Date; import java.util.List; import org.jtester.unitils.dbfit.DbFit; import org.testng.Assert; import org.testng.annotations.Test; import org.unitils.spring.annotation.SpringBeanByName; import xxx.BaseRegionDbDAOTestCase; import xxx.constant.group.Nic; import xxx.constant.group.Policy; import xxx.dao.regiondb.VmDAO; import xxx.model.db.VmDO; public class VmDAOTest extends BaseRegionDbDAOTestCase { @SpringBeanByName private VmDAO vmDAO; @Test @DbFit(when="VmDAOTest.initBlank.when.wiki", then="VmDAOTest.queryOneRecord.then.wiki") public void testInsertVmDO() { VmDO vm = new VmDO(); vm.setId(1); vm.setVmName("test-vm_name"); vm.setCores(0); vm.setMem(0); vm.setDisk(0); vm.setStatus(0); vm.setNcId(1); vm.setIsDeleted(0); vm.setGmtCreate(new Date()); vm.setGmtModify(new Date()); vmDAO.insertVmDO(vm); } @Test @DbFit(when="VmDAOTest.initRecords.when.wiki") public void testCountVmDOByExample() { VmDO vm = new VmDO(); Assert.assertTrue(vmDAO.countVmDOByExample(vm).intValue() == 1); } @Test @DbFit(when="VmDAOTest.initRecords.when.wiki", then="VmDAOTest.testUpdate.then.wiki") public void testUpdateVmDO() { VmDO vm = vmDAO.findVmDOByPrimaryKey(); vm.setId(1); vm.setVmName("test-vm_name"); vm.setCores(0); vm.setMem(0); vm.setDisk(0); vm.setStatus(0); vm.setNcId(1); vm.setIsDeleted(0); vmDAO.updateVmDO(vm); } @Test @DbFit(when="VmDAOTest.initRecords.when.wiki") public void testFindListByExample() { VmDO vm = new VmDO(); vm.setId(1); vm.setVmName("test-vm_name"); vm.setCores(0); vm.setMem(0); vm.setDisk(0); vm.setStatus(0); vm.setNcId(1); vm.setIsDeleted(0); List<VmDO> list = vmDAO.findListByExample(vm); Assert.assertEquals(list.size(), 1); for (VmDO vmDO: list) { Assert.assertEquals(vm.getId(), 1) Assert.assertEquals(vm.getVmName(), "test-vm_name") Assert.assertEquals(vm.getCores(), 0) Assert.assertEquals(vm.getMem(), 0) Assert.assertEquals(vm.getDisk(), 0) Assert.assertEquals(vm.getStatus(), 0) Assert.assertEquals(vm.getNcId(), 1) 
Assert.assertEquals(vm.getIsDeleted(), 0) } } @Test @DbFit(when="VmDAOTest.initRecords.when.wiki") public void testFindVmDOByPrimaryKey() { VmDO vm = vmDAO.findVmDOByPrimaryKey(1L); Assert.assertEquals(vm.getId(), 1) Assert.assertEquals(vm.getVmName(), "test-vm_name") Assert.assertEquals(vm.getCores(), 0) Assert.assertEquals(vm.getMem(), 0) Assert.assertEquals(vm.getDisk(), 0) Assert.assertEquals(vm.getStatus(), 0) Assert.assertEquals(vm.getNcId(), 1) Assert.assertEquals(vm.getIsDeleted(), 0) } @Test @DbFit(when="VmDAOTest.initRecords.when.wiki") public void testDeleteVmDOByPrimaryKey() { Integer count = vmDAO.deleteVmDOByPrimaryKey(1L); Assert.assertEquals(count.intValue(), 1); Integer nodelete = vmDAO.deleteVmDOByPrimaryKey(1L); Assert.assertEquals(nodelete.intValue(), 0); } }
结语:
只要是手工劳动, 尽可能自动化。而要做到自动化, 第一是规范标准化, 第二是要发现一些规律性的模式。