zoukankan      html  css  js  c++  java
  • Airtest源码分析--图像识别整体流程

    上期回顾:Airtest-API精讲之Template


    以下基于
    python3.8;airtest1.2.2;pocoui1.0.83

    之前讲了图像识别的基础——Template类:Airtest-API精讲之Template
    这次我们看下Airtest图像识别的整体流程。

    我们以touch()接口为例,AirtestIDE中touch怎么用可以看:AirtestIDE基本功能(一)

    进入查看touch源码

    # 源码路径 your_python_path/site-packages/airtest/core/api.py
    def touch(v, times=1, **kwargs):
        """
        Perform the touch action on the device screen

        :param v: target to touch, either a ``Template`` instance or absolute coordinates (x, y)
        :param times: how many touches to be performed
        :param kwargs: platform specific `kwargs`, please refer to corresponding docs
        :return: finial position to be clicked, e.g. (100, 100)
        """
        if isinstance(v, Template):
            pos = loop_find(v, timeout=ST.FIND_TIMEOUT)
        else:
            try_log_screen()
            pos = v
        for _ in range(times):
            G.DEVICE.touch(pos, **kwargs)
            time.sleep(0.05)
        delay_after_operation()
        return pos

    touch是兼容传入图片或坐标的,我们只看图片的逻辑。

    pos = loop_find(v, timeout=ST.FIND_TIMEOUT)

    可以看到是通过loop_find去循环找图,超时时间ST.FIND_TIMEOUT默认是20S,这里找到图片的话会返回坐标,后面的代码会去点击这个坐标,就完成了touch操作。

    继续进入loop_find源码:

    # 源码路径 your_python_path/site-packages/airtest/core/cv.py
    def loop_find(query, timeout=ST.FIND_TIMEOUT, threshold=None, interval=0.5, intervalfunc=None):
        G.LOGGING.info("Try finding: %s", query)
        start_time = time.time()
        while True:
            screen = G.DEVICE.snapshot(filename=None, quality=ST.SNAPSHOT_QUALITY)

            if screen is None:
                G.LOGGING.warning("Screen is None, may be locked")
            else:
                if threshold:
                    query.threshold = threshold
                match_pos = query.match_in(screen)
                if match_pos:
                    try_log_screen(screen)
                    return match_pos

            if intervalfunc is not None:
                intervalfunc()

            # 超时则raise,未超时则进行下次循环:
            if (time.time() - start_time) > timeout:
                try_log_screen(screen)
                raise TargetNotFoundError('Picture %s not found in screen' % query)
            else:
                time.sleep(interval)

    loop_find整体逻辑就是循环去屏幕截图上找图,找到返回其坐标,超时未找到报错。第1个参数query就是我们前面传入的Template类实例(我们截的图)

    其中关键是match_pos = query.match_in(screen),前一步给手机截图赋值给screen,然后在截图中查找给定图片,用的方法是Template类中的match_in方法。

    继续看match_in源码:

    # 源码路径 your_python_path/site-packages/airtest/core/cv.py
    def match_in(self, screen):
        match_result = self._cv_match(screen)
        G.LOGGING.debug("match result: %s", match_result)
        if not match_result:
            return None
        focus_pos = TargetPos().getXY(match_result, self.target_pos)
        return focus_pos

    其中核心代码是match_result = self._cv_match(screen)图像匹配

    如果找到后面代码会返回9宫点中我们要求的坐标:

    focus_pos = TargetPos().getXY(match_result, self.target_pos)

    还得记得9宫点吗?就是Template实例化时我们指定的target_pos,忘了可以看这篇Airtest-API精讲之Template中的target_pos

    继续看_cv_match源码:

    # 源码路径 your_python_path/site-packages/airtest/core/cv.py
        def _cv_match(self, screen):
            # in case image file not exist in current directory:
            ori_image = self._imread()
            image = self._resize_image(ori_image, screen, ST.RESIZE_METHOD)
            ret = None
            for method in ST.CVSTRATEGY:
                # get function definition and execute:
                func = MATCHING_METHODS.get(method, None)
                if func is None:
                    raise InvalidMatchingMethodError("Undefined method in CVSTRATEGY: '%s', try 'kaze'/'brisk'/'akaze'/'orb'/'surf'/'sift'/'brief' instead." % method)
                else:
                    if method in ["mstpl", "gmstpl"]:
                        ret = self._try_match(func, ori_image, screen, threshold=self.threshold, rgb=self.rgb, record_pos=self.record_pos,resolution=self.resolution, scale_max=self.scale_max, scale_step=self.scale_step)
                    else:
                        ret = self._try_match(func, image, screen, threshold=self.threshold, rgb=self.rgb)
                if ret:
                    break
            return ret

    其中ori_image = self._imread()读取图像

    image = self._resize_image(ori_image, screen, ST.RESIZE_METHOD)

    根据分辨率,将输入的截图适配成 等待模板匹配的截图

    之后会循环各种算法去匹配图片,默认算法为ST.CVSTRATEGY = ["mstpl", "tpl", "surf", "brisk"]

    循环中用到的匹配方法为_try_match

    继续看_try_match源码:

    # 源码路径 your_python_path/site-packages/airtest/core/cv.py
        def _try_match(func, *args, **kwargs):
            G.LOGGING.debug("try match with %s" % func.__name__)
            try:
                ret = func(*args, **kwargs).find_best_result()
            except aircv.NoModuleError as err:
                G.LOGGING.warning("'surf'/'sift'/'brief' is in opencv-contrib module. You can use 'tpl'/'kaze'/'brisk'/'akaze'/'orb' in CVSTRATEGY, or reinstall opencv with the contrib module.")
                return None
            except aircv.BaseError as err:
                G.LOGGING.debug(repr(err))
                return None
            else:
                return ret

    其核心代码为ret = func(*args, **kwargs).find_best_result()

    不同的算法对应不同的find_best_result()方法,目前一共有4种,我们以TemplateMatching类中的为例看一下

    # 源码路径 your_python_path/site-packages/airtest/aircv/template_matching.py
    def find_best_result(self):
        """基于kaze进行图像识别,只筛选出最优区域."""
        """函数功能:找到最优结果."""
        # 第一步:校验图像输入
        check_source_larger_than_search(self.im_source, self.im_search)
        # 第二步:计算模板匹配的结果矩阵res
        res = self._get_template_result_matrix()
        # 第三步:依次获取匹配结果
        min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
        h, w = self.im_search.shape[:2]
        # 求取可信度:
        confidence = self._get_confidence_from_matrix(max_loc, max_val, w, h)
        # 求取识别位置: 目标中心 + 目标区域:
        middle_point, rectangle = self._get_target_rectangle(max_loc, w, h)
        best_match = generate_result(middle_point, rectangle, confidence)
        LOGGING.debug("[%s] threshold=%s, result=%s" % (self.METHOD_NAME, self.threshold, best_match))

        return best_match if confidence >= self.threshold else None

    到这里就是基于cv2库去找图了,步骤注释写的很清楚了。对opencv感兴趣的同学,可以到这里学一学http://www.woshicver.com/

  • 相关阅读:
    为什么有时候程序出问题会打印出“烫烫烫烫...
    VC++共享数据段实现进程之间共享数据
    IEEE浮点数float、double的存储结构
    前端智勇大闯关
    Python:高级主题之(属性取值和赋值过程、属性描述符、装饰器)
    来认识下less css
    Koala Framework
    在使用Kettle的集群排序中 Carte的设定——(基于Windows)
    标准库类型
    iOS多线程的初步研究1
  • 原文地址:https://www.cnblogs.com/songzhenhua/p/15365791.html
Copyright © 2011-2022 走看看