  • The EAST text detection model in practice: loading the ckpt and pb with TF, loading with OpenCV

    References: https://github.com/argman/EAST (project source)

                https://github.com/opencv/opencv/issues/12491 (a problem I ran into)

                https://www.pyimagesearch.com/2018/08/20/opencv-text-detection-east-text-detector/ (loading with OpenCV)

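    There are plenty of good off-the-shelf models for text detection, such as YOLOv3, PSENet, pennet, and EAST. I won't go through them one by one; this post covers how I got EAST running.

    Download the project from https://github.com/argman/EAST. I ran it on Windows, where the package versions have to be right, otherwise you get all kinds of confusing errors.

    Run EAST/eval.py. There is nothing special to point out. On CPU, a single 640×480 image takes about 0.4 seconds, which is very good. Chinese, English, and digits are all detected.

    However, the released model is a ckpt, which is very large; converting it to a pb makes it somewhat smaller. Add: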

    # Generate the pb model (model.py also needs to be modified)
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        self.sess,  # the session is used to retrieve the weights
        tf.get_default_graph().as_graph_def(),  # the graph_def is used to retrieve the nodes
        ["feature_fusion/Conv_7/Sigmoid", "feature_fusion/concat_3"]  # output node names
    )
    # raw string, so that backslashes such as \v are not read as escape sequences
    output_graph = r'D:\work\video\hand_tracking_no_op\hand_tracking\EAST\east_icdar2015_resnet_v1_50_rbox\out.pb'
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))
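A Windows path written as an ordinary Python string literal can silently change, because sequences such as `\v` and `\t` are escape codes. The sketch below illustrates this with a made-up path (`D:\video\table.pb` is hypothetical, chosen only because it contains two such sequences) and shows two common ways around it: a raw string, or forward slashes.

```python
import ntpath  # Windows path rules, importable on any operating system

# Hypothetical path, used only to demonstrate the escaping behaviour.
plain = 'D:\video\table.pb'    # \v and \t are consumed as escape sequences
raw = r'D:\video\table.pb'     # raw string: backslashes stay literal
forward = 'D:/video/table.pb'  # forward slashes also work in Windows paths

print(repr(plain))                      # shows the control characters that crept in
print(len(plain), len(raw))             # the plain literal is two characters shorter
print(ntpath.basename(raw))             # file name as Windows would see it
print(ntpath.normpath(forward) == raw)  # both spellings name the same file
```

Either the raw string or the forward-slash form is a safe choice for the `output_graph` path.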

    This goes in eval.py, right after saver.restore(self.sess, model_path). Note that if you want OpenCV to load the pb, you also have to modify model.py; that is covered in a follow-up post.
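Reading the frozen file back requires a small loader. The loading code in this post calls `load_graph` without showing its body, so the following is only a minimal sketch of what such a helper could look like, assuming TensorFlow 1.13+ or 2.x (via `tf.compat.v1`). The `tensor_name` helper and the `prefix` parameter are hypothetical, added for illustration; `'import'` is the prefix that the tensor names used in this post expect.

```python
def tensor_name(node, prefix='import', index=0):
    """Build the name of output `index` of `node` imported under `prefix`."""
    return '{}/{}:{}'.format(prefix, node, index)


def load_graph(pb_path, prefix='import'):
    """Read a frozen GraphDef from pb_path and import it into a new tf.Graph.

    Sketch only: the original post does not show its own load_graph.
    """
    import tensorflow as tf  # imported lazily; assumes TF 1.13+ or 2.x

    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph so nothing else in the process is mixed in.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)
    return graph


# The three tensor names used when running the frozen EAST model:
print(tensor_name('input_images'))
print(tensor_name('feature_fusion/Conv_7/Sigmoid'))
print(tensor_name('feature_fusion/concat_3'))
```

Importing under an explicit prefix keeps the lookup names predictable: every node from the pb appears as `import/<original name>`.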
    Once the pb has been generated, load it with TF, in much the same way as the ckpt:

    import os
    import time

    import cv2
    import numpy as np
    import tensorflow as tf

    # FLAGS, get_images, resize_image, detect and sort_poly come from EAST's eval.py;
    # load_graph is the author's own helper and is not shown in this post.
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list

    try:
        os.makedirs(FLAGS.output_dir)
    except OSError as e:
        if e.errno != 17:  # 17 == EEXIST: the directory already exists
            raise

    print("load_graph")
    graph = load_graph(FLAGS.checkpoint_path)

    # tensor names carry the 'import/' prefix of the imported frozen graph
    input_images = graph.get_tensor_by_name('import/input_images:0')
    f_score = graph.get_tensor_by_name('import/feature_fusion/Conv_7/Sigmoid:0')
    f_geometry = graph.get_tensor_by_name('import/feature_fusion/concat_3:0')

    with tf.Session(graph=graph) as sess:
        im_fn_list = get_images()
        for im_fn in im_fn_list:
            im = cv2.imread(im_fn)[:, :, ::-1]  # BGR -> RGB
            start_time = time.time()
            im_resized, (ratio_h, ratio_w) = resize_image(im)

            timer = {'net': 0, 'restore': 0, 'nms': 0}
            start = time.time()

            # file_writer = tf.summary.FileWriter('tmp/log', sess.graph)

            score, geometry = sess.run([f_score, f_geometry],
                                       feed_dict={input_images: [im_resized]})
            timer['net'] = time.time() - start

            boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer)
            print('{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format(
                im_fn, timer['net']*1000, timer['restore']*1000, timer['nms']*1000))

            if boxes is not None:
                # map the boxes back to the coordinates of the original image
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h

            duration = time.time() - start_time
            print('[timing] {}'.format(duration))

            # save to file
            if boxes is not None:
                res_file = os.path.join(
                    FLAGS.output_dir,
                    '{}.txt'.format(os.path.basename(im_fn).split('.')[0]))

                with open(res_file, 'w') as f:
                    for box in boxes:
                        # to avoid submitting errors
                        box = sort_poly(box.astype(np.int32))
                        if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5:
                            continue
                        # one box per line
                        f.write('{},{},{},{},{},{},{},{}\n'.format(
                            box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1],
                        ))
                        cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1)
            if not FLAGS.no_write_images:
                img_path = os.path.join(FLAGS.output_dir, os.path.basename(im_fn))
                cv2.imwrite(img_path, im[:, :, ::-1])
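The post-processing in the loop above does three things to each detected box: it divides the coordinates by the resize ratios to return to the original image, drops boxes whose sides are shorter than 5 pixels, and writes the eight coordinates as one comma-separated line. The sketch below mirrors that logic in plain Python without NumPy, purely as an illustration; the function names and the sample numbers are made up and are not part of EAST.

```python
import math


def rescale_quad(quad, ratio_w, ratio_h):
    """Map four (x, y) points from the resized image back to the original."""
    return [(x / ratio_w, y / ratio_h) for x, y in quad]


def is_too_small(quad, min_side=5):
    """True when edge p0-p1 or edge p3-p0 is shorter than min_side pixels."""
    p0, p1, _, p3 = quad
    side_a = math.hypot(p0[0] - p1[0], p0[1] - p1[1])
    side_b = math.hypot(p3[0] - p0[0], p3[1] - p0[1])
    return side_a < min_side or side_b < min_side


def format_quad(quad):
    """Format as x0,y0,x1,y1,x2,y2,x3,y3 with integer coordinates."""
    return ','.join(str(int(v)) for point in quad for v in point)


# Illustrative box found on an image that was shrunk to half its size.
detected = [(64, 32), (192, 32), (192, 64), (64, 64)]
original = rescale_quad(detected, ratio_w=0.5, ratio_h=0.5)
if not is_too_small(original):
    print(format_quad(original))
```

The ratios are defined as resized size divided by original size, which is why the coordinates are divided rather than multiplied.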

    That covers converting the EAST ckpt to a pb and loading it with TF.
    The next post covers loading the EAST pb with OpenCV.



  • Original article: https://www.cnblogs.com/zwczp/p/12769222.html