zoukankan      html  css  js  c++  java
  • 吴裕雄 PYTHON 人工智能——基于MASK_RCNN目标检测(5)

    import os
    import sys
    import numpy as np
    import tensorflow as tf
    import matplotlib
    import matplotlib.pyplot as plt
    import keras
    
    import utils
    import model as modellib
    import visualize
    from model import log
    
    %matplotlib inline 
    
    # Every path below is resolved against the current working directory,
    # which is treated as the project root.
    ROOT_DIR = os.getcwd()
    MODEL_DIR = os.path.join(ROOT_DIR, "logs")  # where logs and trained models go
    COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
    SHAPES_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_shapes.h5")

    # Fetch the COCO-trained weights from the releases page only when no
    # local copy exists yet.
    if not os.path.exists(COCO_MODEL_PATH):
        utils.download_trained_weights(COCO_MODEL_PATH)

    # Choose exactly ONE dataset configuration below.
    #
    # Shapes toy dataset:
    # import shapes
    # config = shapes.ShapesConfig()
    #
    # MS COCO dataset (currently active):
    import coco
    config = coco.CocoConfig()

    # Device the network is built on: "/cpu:0" or "/gpu:0". Choosing the CPU
    # leaves the GPU free when a model is being trained on the same machine.
    DEVICE = "/cpu:0"
    def get_ax(rows=1, cols=1, size=16):
        """Create a new figure and return its Axes for notebook visualizations.

        Routing all figure creation through this helper gives one place to
        control how large the plots render.

        Args:
            rows: Number of subplot rows.
            cols: Number of subplot columns.
            size: Per-subplot size; the figure size passed to matplotlib is
                ``(size * cols, size * rows)``.

        Returns:
            The Axes object (or array of Axes) produced by ``plt.subplots``.
        """
        figure_size = (size * cols, size * rows)
        axes = plt.subplots(rows, cols, figsize=figure_size)[1]
        return axes
    # Build the model in inference mode, creating its ops on DEVICE.
    with tf.device(DEVICE):
        model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
                                  config=config)

    # Select the weights file that matches the active config.
    weights_path = None
    if config.NAME == "shapes":
        weights_path = SHAPES_MODEL_PATH
    elif config.NAME == "coco":
        weights_path = COCO_MODEL_PATH
    # Or, uncomment to load the last model you trained
    # weights_path = model.find_last()[1]

    # Without this check an unrecognised config.NAME would only surface below
    # as a confusing NameError on `weights_path`; fail early and explicitly.
    if weights_path is None:
        raise ValueError(
            "No weights file configured for config.NAME={!r}; "
            "set weights_path explicitly.".format(config.NAME))

    # Load weights, matching layers by name (by_name=True).
    print("Loading weights ", weights_path)
    model.load_weights(weights_path, by_name=True)
    # Show stats of all trainable weights
    visualize.display_weight_stats(model)

    # Layer types whose weights are plotted.
    LAYER_TYPES = ['Conv2D', 'Dense', 'Conv2DTranspose']
    # Keep only the trainable layers whose class name is one of LAYER_TYPES.
    layers = [layer for layer in model.get_trainable_layers()
              if layer.__class__.__name__ in LAYER_TYPES]

    # Display histograms: one row per layer, one column per weight array of
    # that layer. Two columns are allotted, presumably enough for these layer
    # types (kernel plus optional bias) -- confirm if LAYER_TYPES is extended.
    if layers:
        # squeeze=False keeps `ax` a 2-D array even when only one layer
        # matched, so ax[row, col] indexing is always valid.
        fig, ax = plt.subplots(len(layers), 2, figsize=(10, 3 * len(layers)),
                               gridspec_kw={"hspace": 1}, squeeze=False)
        for row, layer in enumerate(layers):
            # layer.weights gives the variables (for their names) and
            # get_weights() their current values, paired by position.
            for col, (tensor, weight) in enumerate(
                    zip(layer.weights, layer.get_weights())):
                ax[row, col].set_title(tensor.name)
                # Histogram every value of this weight array; indexing
                # weight[col] would plot only one slice of the kernel and a
                # single bias entry.
                _ = ax[row, col].hist(weight.flatten(), 50)
    else:
        # plt.subplots(0, ...) would raise, so report instead of plotting.
        print("No trainable layers of types", LAYER_TYPES)

  • 相关阅读:
    前端学PHP之面向对象系列第一篇——类和对象
    前端学PHP之面向对象系列第三篇——三大特性
    mysql之视图
    PHP中冒号、endif、endwhile、endfor这些都是什么
    abap将内表数据导出为excel文件
    调用、跳转到事务码、报表
    <转>sapjco3.dll already loaded in another classloader错误解决办法
    sap权限检查
    <转>DIALOG条件 CHAIN 语句
    ABAP DOI 下载SMW0的EXCEL和WORD模板
  • 原文地址:https://www.cnblogs.com/tszr/p/10868171.html
Copyright © 2011-2022 走看看