zoukankan      html  css  js  c++  java
  • 网络结构可视化方法

    方法一:输出为PDF文档(使用graphviz)

    from graphviz import Digraph
    import torch
    from torch.autograd import Variable
    
    
    def make_dot(var, params=None):
        """ Produces Graphviz representation of PyTorch autograd graph
        Blue nodes are the Variables that require grad, orange are Tensors
        saved for backward in torch.autograd.Function
        Args:
            var: output Variable
            params: dict of (name, Variable) to add names to node that
                require grad (TODO: make optional)
        """
        if params is not None:
            assert isinstance(params.values()[0], Variable)
            param_map = {id(v): k for k, v in params.items()}
    
        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
        seen = set()
    
        def size_to_str(size):
            return '('+(', ').join(['%d' % v for v in size])+')'
    
        def add_nodes(var):
            if var not in seen:
                if torch.is_tensor(var):
                    dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
                elif hasattr(var, 'variable'):
                    u = var.variable
                    name = param_map[id(u)] if params is not None else ''
                    node_name = '%s
     %s' % (name, size_to_str(u.size()))
                    dot.node(str(id(var)), node_name, fillcolor='lightblue')
                else:
                    dot.node(str(id(var)), str(type(var).__name__))
                seen.add(var)
                if hasattr(var, 'next_functions'):
                    for u in var.next_functions:
                        if u[0] is not None:
                            dot.edge(str(id(u[0])), str(id(var)))
                            add_nodes(u[0])
                if hasattr(var, 'saved_tensors'):
                    for t in var.saved_tensors:
                        dot.edge(str(id(t)), str(id(var)))
                        add_nodes(t)
        add_nodes(var.grad_fn)
        return dot
    itLEP_pil, itLEP_np = get_image(real_face_name, imsize)
    net = skip(input_depth, itLEP_np.shape[0], 
               num_channels_down = [128] * 5,
               num_channels_up =   [128] * 5,
               num_channels_skip =    [128] * 5,
               filter_size_up = 3, filter_size_down = 3,
               upsample_mode='nearest', filter_skip_size=1,
               need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU').type(dtype)
    
    dummy_input = get_noise(input_depth, INPUT, itLEP_np.shape[1:]).type(dtype)
    #上面为定义网络结构,以及定义输入;下面为输出网络结构图 y
    = net(dummy_input) g = make_dot(y) g.view()

    方法二:使用tensorboardX

    import torch
    import torch.nn as nn
    from tensorboardX import SummaryWriter
    class LeNet(nn.Module):
        def __init__(self):
            super(LeNet, self).__init__()
            self.conv1 = nn.Sequential(     #input_size=(1*28*28)
                nn.Conv2d(1, 6, 5, 1, 2),
                nn.ReLU(),      #(6*28*28)
                nn.MaxPool2d(kernel_size=2, stride=2),  #output_size=(6*14*14)
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),      #(16*10*10)
                nn.MaxPool2d(2, 2)  #output_size=(16*5*5)
            )
            self.fc1 = nn.Sequential(
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU()
            )
            self.fc2 = nn.Sequential(
                nn.Linear(120, 84),
                nn.ReLU()
            )
            self.fc3 = nn.Linear(84, 10)
    
        # 定义前向传播过程,输入为x
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
            x = x.view(x.size()[0], -1)
            x = self.fc1(x)
            x = self.fc2(x)
            x = self.fc3(x)
            return x
    
    dummy_input = torch.rand(13, 1, 28, 28) #假设输入13张1*28*28的图片
    model = LeNet()
    with SummaryWriter(comment='LeNet') as w:
        w.add_graph(model, (dummy_input, ))

     这里运行后会生成runs文件夹,切换到runs所在的目录,

    使用 tensorboard --logdir runs该命令,得到浏览器地址,在不同的浏览器打开(因为有些浏览器打开看不到任何东西)

    双击图的结构,出现网络细节图

  • 相关阅读:
    require.js使用
    favico是针对网页图标内容更改
    web图片转换小工具制作
    控制显示input隐藏和查看密码
    程序员图片注释字符串制作工具
    c语言基础, , ,
    【理解】column must appear in the GROUP BY clause or be used in an aggregate function
    ps aux命令解析
    while(std::cin>>val)怎么结束的思考
    【转】NativeScript的工作原理:用JavaScript调用原生API实现跨平台
  • 原文地址:https://www.cnblogs.com/hxjbc/p/10972092.html
Copyright © 2011-2022 走看看