zoukankan      html  css  js  c++  java
  • Unittest二次开发实战-02-定制TestResult类支持多线程并发捕获用例输出

    需求及问题

    在框架中,捕获并记录用例运行时的print输出信息是非常重要的。如下图。
    image.png
    在原始的unittest.TextTestResult中,这些输出信息是直接输出终端的,当然也将全部的运行结果输出到文本文件中。
    如果需要分别捕获不同用例的输出信息,一般采用重定向系统输入流的方式来实现,如下图。
    image.png
    每个用例在开始执行前将系统标准输出流重定向到一个自定义的io.String()流中,这样用例中的print等信息就会输出到此流中。在用例结束时获取该流中的数据,并还原系统标准输出流。

    这种模式,用例在执行期间需要独占系统输出流。在用例多线程并发时(多进程并发无此问题),使用同一系统输出流,便会导致不同用例的输出捕获混乱。如果对线程加锁则相当于串行,并发便没有意义。

    解决思路

    使用自定义的io流,分线程存储结果,用例运行时记录下其线程id,执行结束时重全部的io流中取出用例线程id对应的输出信息,如下图。
    image.png

    实现步骤

    1. 自定义io流按线程存储输出的信息
    import io
    import threading
    
    class ThreadStringIO(io.StringIO):
        """按线程存取结果"""
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.values = defaultdict(str)  # 按线程号存储
    
        def write(self, text):
            # log.info('线程输出', text)
            thread_id = str(threading.current_thread().ident)
            super().write(text)
            self.values[thread_id] += text
    
        def writelines(self, lines):
            super().writelines(lines)
            thread_id = str(threading.current_thread().ident)
            text = ''.join(lines)
            self.values[thread_id] += text
    
        def getvalues(self):
            """返回线程结果字典"""
            return self.values
    
    1. 重写父类_setupStdout中重定向的io流
        def _setupStdout(self):
            """分线程存储"""
    
            if self.buffer:
                if self._stderr_buffer is None:
                    # self._stderr_buffer = io.StringIO()
                    # self._stdout_buffer = io.StringIO()
                    self._stderr_buffer = ThreadStringIO()
                    self._stdout_buffer = ThreadStringIO()
                sys.stdout = self._stdout_buffer
                sys.stderr = self._stderr_buffer
    
    1. 重写_restoreStdout方法并返回按线程存储的io流数据
        def _restoreStdout(self):
            """重写父类的_restoreStdout方法"""
            if self.buffer:
                # if self._mirrorOutput:
                #     output = sys.stdout.getvalue()
                #     error = sys.stderr.getvalue()
                outputs = self._stdout_buffer.getvalues()
                errors = self._stderr_buffer.getvalues()
                
                sys.stdout = self._original_stdout
                sys.stdout = self._original_stderr
                self._stdout_buffer.seek(0)
                self._stdout_buffer.truncate()
                self._stderr_buffer.seek(0)
                self._stderr_buffer.truncate()
                return outputs, errors
    
    1. 用例运行开始时记录运行的线程id
        def startTest(self, test: unittest.case.TestCase):
            """单个用例执行开始"""
            test.result = TestCaseResult(test)
            self.testcase_results.append(test.result)
            test.result.thread_id = str(threading.current_thread().ident)  # 线程id
            test.result.start_at = time.time()
            self.testsRun += 1
            # self._mirrorOutput = False
            self._setupStdout()
    
    1. 用例运行结束时从_restoreStdout返回的流数据中按用例线程id取出捕获的输出
        def stopTest(self, test: unittest.case.TestCase) -> None:
            """单个用例结束"""
            test.result.end_at = time.time()
            test.result.duration = test.result.end_at - test.result.start_at
    
            outputs, errors = self._restoreStdout()
            test_output = outputs.get(test.result.thread_id, '') + errors.get(test.result.thread_id, '')
            test.result.output = test_output.strip()
            ...
    

    完整代码

    import inspect
    import platform
    import sys
    import time
    import traceback
    import unittest
    import io
    from unittest.result import failfast
    import re
    import threading
    from collections import defaultdict
    
    from logz import log
    log.format = '%(asctime)s %(threadName)s %(thread)d %(levelname)s %(message)s'
    
    
    TAG_PARTTEN = 'tag:(w+)'
    LEVEL_PARTTEN = 'level:(d+)'
    
    
    def get_case_tags(case: unittest.TestCase) -> list:
        """从用例方法的docstring中匹配出指定格式的tags"""
        case_tags = None
        case_doc = case._testMethodDoc
        if case_doc and 'tag' in case_doc:
            pattern = re.compile(TAG_PARTTEN)
            case_tags = re.findall(pattern, case_doc)
        return case_tags
    
    
    def get_case_level(case: unittest.TestCase):
        """从用例方法的docstring中匹配出指定格式的level"""
        case_doc = case._testMethodDoc
        case_level = None  # todo 默认level
        if case_doc:
            pattern = re.compile(LEVEL_PARTTEN)
            levels = re.findall(pattern, case_doc)
            if levels:
                case_level = levels[0]
                try:
                    case_level = int(case_level)
                except:
                    raise ValueError(f'用例中level设置:{case_level} 应为整数格式')
        return case_level
    
    
    class TestStatus(object):
        SUCCESS = 'success'
        FAIL = 'fail'
        ERROR = 'error'
        SKIPPED = 'skipped'
        XFAIL = 'xfail'
        XPASS = 'xpass'
    
    
    def time_to_string(timestamp: float) -> str:
        """时间戳转时间字符串"""
        time_array = time.localtime(timestamp)
        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time_array)
        return time_str
    
    
    def get_platform_info():
        """获取执行平台信息"""
        return {
            "platform": platform.platform(),
            "system": platform.system(),
            "python_version": platform.python_version(),
            # "env": dict(os.environ),  # 可能包含敏感信息
        }
    
    
    def inspect_code(test):
        test_method = getattr(test.__class__, test._testMethodName)
        try:
            code = inspect.getsource(test_method)
        except Exception as ex:
            log.exception(ex)
            code = ''
        return code
    
    
    class ThreadStringIO(io.StringIO):
        """按线程存取结果"""
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.values = defaultdict(str)  # 按线程号存储
    
        def write(self, text):
            # log.info('线程输出', text)
            thread_id = str(threading.current_thread().ident)
            super().write(text)
            self.values[thread_id] += text
    
        def writelines(self, lines):
            super().writelines(lines)
            thread_id = str(threading.current_thread().ident)
            text = ''.join(lines)
            self.values[thread_id] += text
    
        def getvalues(self):
            """返回线程结果字典"""
            return self.values
    
    
    class TestCaseResult(object):
        """用例测试结果"""
    
        def __init__(self, test: unittest.case.TestCase, name=None):
            self.test = test  # 测试用例对象
    
            self.name = name or test._testMethodName  # 支持传入用例别名,unittest.TestCase自带属性方法
            self.id = test.id()  # 用例完整路径,unittest.TestCase自带方法
            self.description = test.shortDescription()  # 用例简要描述,unittest.TestCase自带方法
            self.doc = test._testMethodDoc  # 用例docstring,,unittest.TestCase自带属性方法
            self.module_name = test.__module__  # 用例所在模块名
            self.class_name = test.__class__.__name__  # 用例所在类名
            self.class_id = f'{test.__module__}.{test.__class__.__name__}'  # 用例所在类完整路径
            self.class_doc = test.__class__.__doc__  # 用例所在类docstring描述
    
            self.tags = get_case_tags(test)  # 获取用例tags
            self.level = get_case_level(test)  # 获取用例level等级
            self.code = inspect_code(test)  # 获取用例源代码
    
            # 用例执后更新的信息
            self.start_at = None  # 用例开始时间
            self.end_at = None  # 用例结束时间
            self.duration = None  # 用例执行持续时间
    
            self.status = None  # 用例测试状态
            self.output = None  # 用例内的print信息
            self.exc_info = None  # 用例异常信息
            self.reason = None  # 跳过,失败,出错原因
    
        @property
        def data(self):  # 组合字典格式的用例结果数据
            data = dict(
                name=self.name,
                id=self.test.id(),
                description=self.description,
                status=self.status,
                tags=self.tags,
                level=self.level,
                time=dict(
                    start_at=self.start_at,
                    end_at=self.end_at,
                    duration=self.duration
                ),
                class_name=self.class_name,
                class_doc=self.class_doc,
                module_name=self.module_name,
                code=self.code,
                output=self.output,
                exc_info=self.exc_info,
                reason=self.reason,
            )
            return data
    
    
    class TestResult(unittest.TestResult):
        """测试结果,补充整个过程的运行时间"""
    
        def __init__(self,
                     stream=None,
                     descriptions=None,
                     verbosity=None,
                     ):
    
            super().__init__(stream, descriptions, verbosity)
            self.successes = []
            self.testcase_results = []  # 执行的用例结果列表
            self.verbosity = verbosity or 1
            self.buffer = True
            self.know_exceptions = None
    
            self.name = None
            self.start_at = None
            self.end_at = None
            self.duration = None
            self.successes_count = 0
            self.failures_count = 0
            self.errors_count = 0
            self.skipped_count = 0
            self.expectedFailures_count = 0
            self.unexpectedSuccesses_count = 0
    
            self._stdout_buffers = dict() # 按线程存
            self._stderr_buffers = dict() # 按线程存
            self.threadId = str(threading.current_thread().ident)
    
    
        @property
        def summary(self):
            """组装结果概要, details分按运行顺序和按类组织两种结构"""
            data = dict(
                name=self.name,
                success=self.wasSuccessful(),
                stat=dict(
                    testsRun=self.testsRun,
                    successes=self.successes_count,
                    failures=self.failures_count,
                    errors=self.errors_count,
                    skipped=self.skipped_count,
                    expectedFailures=self.expectedFailures_count,
                    unexpectedSuccesses=self.unexpectedSuccesses_count,
                ),
                time=dict(
                    start_at=self.start_at,
                    end_at=self.end_at,
                    duration=self.duration
                ),
                platform=get_platform_info(),  # 环境信息的最后状态
                details=[testcase_result.data for testcase_result in self.testcase_results]
            )
            return data
    
        def _setupStdout(self):
            """分线程存储"""
            # rlock.acquire()
            # if self._stdout_buffers.get(self.threadId) is None:
            #     self._stdout_buffers[self.threadId] = io.StringIO()
            #     self._stderr_buffers[self.threadId] = io.StringIO()
            #     sys.stdout = self._stdout_buffers[self.threadId]
            #     sys.stderr = self._stderr_buffers[self.threadId]
    
            if self.buffer:
                if self._stderr_buffer is None:
                    # self._stderr_buffer = io.StringIO()
                    # self._stdout_buffer = io.StringIO()
                    self._stderr_buffer = ThreadStringIO()
                    self._stdout_buffer = ThreadStringIO()
                sys.stdout = self._stdout_buffer
                sys.stderr = self._stderr_buffer
    
        def _restoreStdout(self):
            """重写父类的_restoreStdout方法"""
            if self.buffer:
                # if self._mirrorOutput:
                #     output = sys.stdout.getvalue()
                #     error = sys.stderr.getvalue()
                outputs = self._stdout_buffer.getvalues()
                errors = self._stderr_buffer.getvalues()
    
                sys.stdout = self._original_stdout
                sys.stdout = self._original_stderr
                self._stdout_buffer.seek(0)
                self._stdout_buffer.truncate()
                self._stderr_buffer.seek(0)
                self._stderr_buffer.truncate()
                return outputs, errors
    
        def _get_exc_msg(self, err):
            exctype, value, tb = err
            exc_msg = str(value)
            exc_full_path = f'{exctype.__module__}.{exctype.__name__}'
            if self.know_exceptions and isinstance(self.know_exceptions, dict):
                exc_msg = self.know_exceptions.get(exc_full_path, exc_msg)
            return exc_msg
    
        def _exc_info_to_string(self, err, test):
            """重写父类的转换异常方法, 去掉buffer的输出"""
            exctype, value, tb = err
            while tb and self._is_relevant_tb_level(tb):
                tb = tb.tb_next
    
            if exctype is test.failureException:
                # Skip assert*() traceback levels
                length = self._count_relevant_tb_levels(tb)
            else:
                length = None
            tb_e = traceback.TracebackException(
                exctype, value, tb, limit=length, capture_locals=self.tb_locals)
            msgLines = list(tb_e.format())
            return ''.join(msgLines)
    
        def startTestRun(self):
            """整个执行开始"""
            self.start_at = time.time()
            if self.verbosity > 1:
                log.info(f'===== 测试开始, 开始时间: {time_to_string(self.start_at)} =====')
    
        def stopTestRun(self):
            """整个执行结束"""
            self.end_at = time.time()
            self.duration = self.end_at - self.start_at
            self.success = self.wasSuccessful()
            if self.verbosity > 1:
                log.info(f'===== 测试结束, 持续时间: {self.duration}秒 =====')
    
            # log.info(self.summary['stat'])
            # log.info(self.summary['time'])
    
        def startTest(self, test: unittest.case.TestCase):
            """单个用例执行开始"""
            test.result = TestCaseResult(test)
            self.testcase_results.append(test.result)
            test.result.thread_id = str(threading.current_thread().ident)  # 线程id
            test.result.start_at = time.time()
            self.testsRun += 1
            # self._mirrorOutput = False
            self._setupStdout()
    
            if self.verbosity > 1:
                log.info(f'执行用例: {test.result.name}: {test.result.description}, 开始时间: {time_to_string(test.result.start_at)}')
    
        def stopTest(self, test: unittest.case.TestCase) -> None:
            """单个用例结束"""
            test.result.end_at = time.time()
            test.result.duration = test.result.end_at - test.result.start_at
    
            outputs, errors = self._restoreStdout()
            test_output = outputs.get(test.result.thread_id, '') + errors.get(test.result.thread_id, '')
            test.result.output = test_output.strip()
    
            if self.verbosity > 1:
                log.info(f'结果: {test.result.status}, 持续时间: {test.result.duration}秒')
            elif self.verbosity > 0:
                log.info(f'{test.result.name} ...  {test.result.status}')
    
            if self.verbosity > 0:
                if test.result.output:
                    log.info(f'{test.result.output.strip()}')
    
                if test.result.exc_info:
                    log.exception(test.result.exc_info)
    
        def addSuccess(self, test):
            """重写父类方法, 单个用例成功时在stopTest前调用"""
            test.result.status = TestStatus.SUCCESS
            self.successes.append(test)
            self.successes_count += 1
            # super().addSuccess(test)
    
        @failfast
        def addFailure(self, test, err):
            """重写父类方法, 用例失败时在stopTest前调用"""
            test.result.status = TestStatus.FAIL
            test.result.exc_info = self._exc_info_to_string(err, test)
            test.result.reason = self._get_exc_msg(err)
            self.failures_count += 1
            # super().addFailure(test, err)
            self.failures.append((test, self._exc_info_to_string(err, test)))
            # self._mirrorOutput = True
    
        @failfast
        def addError(self, test, err):
            """重写父类方法, 用例异常时在stopTest前调用"""
            test.result.status = TestStatus.ERROR
            test.result.exc_info = self._exc_info_to_string(err, test)
            test.result.reason = self._get_exc_msg(err)
            self.errors_count += 1
            # super().addError(test, err)
            self.errors.append((test, self._exc_info_to_string(err, test)))
            # self._mirrorOutput = True
    
        def addSkip(self, test, reason):
            """重写父类方法, 用例跳过时在stopTest前调用"""
            test.result.status = TestStatus.SKIPPED
            test.result.reason = reason
            self.skipped_count += 1
            # super().addSkip(test, reason)
            self.skipped.append((test, reason))
    
        def addExpectedFailure(self, test, err):
            """重写父类方法, 用例期望失败时在stopTest前调用"""
            test.result.status = TestStatus.XFAIL
            test.result.exc_info = self._exc_info_to_string(err, test)
            test.result.reason = self._get_exc_msg(err)
            self.expectedFailures_count += 1
            # super().addExpectedFailure(test, err)
            self.expectedFailures.append(
                (test, self._exc_info_to_string(err, test)))
    
        @failfast
        def addUnexpectedSuccess(self, test):
            """重写父类方法, 用例非期望成功时在stopTest前调用"""
            test.result.status = TestStatus.XPASS
            self.expectedFailures_count += 1
            # super().addUnexpectedSuccess(test)
            self.unexpectedSuccesses.append(test)
    
  • 相关阅读:
    jzoj 3176. 【GDOI2013模拟5】蜘蛛侠
    各种各样的根号算法 总结&刷题
    jzoj 3187. 【GDOI2013模拟8】的士
    jzoj 3188. 【GDOI2013模拟8】找数
    jzoj 4673. 【NOIP2016提高A组模拟7.20】LCS again
    jzoj 4672. 【NOIP2016提高A组模拟7.20】Graph Coloring
    markdown 模板2
    树莓派kali开启arp【arpspoof,urlsnarf】
    Java 图片处理——如何生成高清晰度而占有磁盘小的缩略图
    手把手教你生成二维码-google.zxing
  • 原文地址:https://www.cnblogs.com/superhin/p/13755684.html
Copyright © 2011-2022 走看看