zoukankan      html  css  js  c++  java
  • 魔改keras中的model.summary()

    之前做项目,好奇这个函数是怎么实现的,我把源码看了一遍,魔改代码,把没用的删除,重新封装为一个类,还加上了可以输出至txt的功能
    '''
    class print_summary_magic_modification:
    def init(self, model, file_path):
    self.model = model
    self.file_path = file_path

    def params_nums(weights):
        return int(np.sum([K.count_params(p) for p in set(weights)]))
    
    
    def print_row(self, fields, positions):
        line = ''
        for i in range(len(fields)):
            if i > 0:
                line = line[:-1] + ' '
            line += str(fields[i])
            line = line[:positions[i]]
            line += ' ' * (positions[i] - len(line))
        print(line)
    
    def print_layer_summary(self, layer, positions):
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = 'multiple'
        name = layer.name
        cls_name = layer.__class__.__name__
        fields = [name + ' (' + cls_name + ')',
                  output_shape, layer.count_params()]
        self.print_row(fields, positions)
    
    def print_summary(self):
        """Prints a summary of a model.
        """
        line_length = 65
        positions = [29, 55, 100]
    
        # header names for the different log elements
        to_display = ['Layer (type)', 'Output Shape', 'Param #']
    
        print('_' * line_length)
        self.print_row(to_display, positions)
        print('=' * line_length)
    
        layers = self.model.layers
        for i in range(len(layers)):
            self.print_layer_summary(layers[i], positions)
            if i == len(layers) - 1:
                print('=' * line_length)
            else:
                print('_' * line_length)
    
    
    def print_summary2txt(self):
        """Prints a summary of a model.
        """
        with open(self.file_path, 'a', encoding='utf-8') as f:
    
            line_length = 65
            positions = [29, 55, 100]
    
            # header names for the different log elements
            to_display = ['Layer (type)', 'Output Shape', 'Param #']
    
            print('_' * line_length)
            self.print_row(to_display, positions)
            print('=' * line_length)
    
            layers = self.model.layers
            for i in range(len(layers)):
                self.print_layer_summary(layers[i], positions)
                if i == len(layers) - 1:
                    print('=' * line_length)
                else:
                    print('_' * line_length)
    

    '''

    下面这个功能可以直接使用model.summary()输出至txt文件,我在google中搜了好久找见的pythonic代码
    '''
    output_file_path = ''
    file_name = ''
    with open(output_file_path + file_name, 'w') as f:
    with redirect_stdout(f):
    model.summary()

    '''

  • 相关阅读:
    Jxl 简单运用 Excel创建,插入数据,图片,更新数据,
    tomcat端口号被占用
    QQ、MSN、淘包旺旺、Skype临时对话的html链接代码
    验证信息
    wpf学习笔记数据绑定功能总结
    wpfStyle注意点
    wpf轻量绘图DrawingVisual
    wpfDrawingBrush注意点
    wpf容易误解的Image
    wpf装饰器
  • 原文地址:https://www.cnblogs.com/joytribianni/p/12000781.html
Copyright © 2011-2022 走看看