zoukankan      html  css  js  c++  java
  • 【pytorch 代码】pytorch 网络结构可视化

    部分内容转载自 http://blog.csdn.net/GYGuo95/article/details/78821617,在此表示由衷感谢。

    此方法需要安装python-graphviz:  conda install -n pytorch python-graphviz 或者 sudo apt-get install graphviz 

    别忘了先把下面的代码下载到自己的路径(感谢大神)。

    visualize.py

    from graphviz import Digraph
    import torch
    from torch.autograd import Variable
    
    
    def make_dot(var, params=None):
        """ Produces Graphviz representation of PyTorch autograd graph
        Blue nodes are the Variables that require grad, orange are Tensors
        saved for backward in torch.autograd.Function
        Args:
            var: output Variable
            params: dict of (name, Variable) to add names to node that
                require grad (TODO: make optional)
        """
        if params is not None:
            assert isinstance(params.values()[0], Variable)
            param_map = {id(v): k for k, v in params.items()}
    
        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
        seen = set()
    
        def size_to_str(size):
            return '('+(', ').join(['%d' % v for v in size])+')'
    
        def add_nodes(var):
            if var not in seen:
                if torch.is_tensor(var):
                    dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
                elif hasattr(var, 'variable'):
                    u = var.variable
                    name = param_map[id(u)] if params is not None else ''
                    node_name = '%s
     %s' % (name, size_to_str(u.size()))
                    dot.node(str(id(var)), node_name, fillcolor='lightblue')
                else:
                    dot.node(str(id(var)), str(type(var).__name__))
                seen.add(var)
                if hasattr(var, 'next_functions'):
                    for u in var.next_functions:
                        if u[0] is not None:
                            dot.edge(str(id(u[0])), str(id(var)))
                            add_nodes(u[0])
                if hasattr(var, 'saved_tensors'):
                    for t in var.saved_tensors:
                        dot.edge(str(id(t)), str(id(var)))
                        add_nodes(t)
        add_nodes(var.grad_fn)
        return dot

    下面是使用方法:

    因人而异,根据网络调整输入,以Inception V3为例。

    from MyInceptionV3 import inception_v3
    import numpy as np
    import torch
    from torch.autograd import Variable
    from visualize import make_dot
    
    if __name__ == '__main__':
        x=np.arange(2*299*299*3)
        x=x.reshape(2,3,299,299)
        x=x/float(x.max())
        x=torch.from_numpy(x)
        x=x.float()
        x=Variable(x)
    
        a = inception_v3(pretrained=True)
    
        y = a(x)
        g = make_dot(y)
        #g.view()    
        g.render('here', view=False)

    我的电脑没有可视化界面,一定要记得False那个view。(网络结构会保存成pdf)结果图太复杂不粘贴了。

  • 相关阅读:
    linux嵌入式系统交叉开发环境
    Codeforces Round #208 E. Dima and Kicks
    mvn 编译错误java.lang.NoSuchMethodError: org.objectweb.asm.ClassWriter. <init>(Z)V
    黑马程序员_<<TCP>>
    微信/易信公共平台开发(四):公众号调试器 (仿真微信平台,提供PHP源码)
    用pdb调试OpenStack Havana
    MySql Odbc等驱动下载地址分享下
    导入exce表格中的数据l到数据库
    关闭数据备份信息写入数据库日志
    SQL Server之RAID简介
  • 原文地址:https://www.cnblogs.com/xiangfeidemengzhu/p/8528802.html
Copyright © 2011-2022 走看看