  • Big Data Development: Applying the Keras Code Framework

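Overall, Keras is a genuinely "easy" deep learning framework. Much of that comes from its fairly detailed documentation: unlike Caffe, where after installation you end up relying on technical blogs, Keras has its own official documentation (in English, though), which gives beginners plenty of room to learn on their own.

Below are my notes on applying the code framework, using an SRGAN (super-resolution GAN) model as the example.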

         

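    from keras.preprocessing.image import ImageDataGenerator

    class VGGNetwork:
        def append_vgg_network(self, x_in, true_X_input):
            # Pass the generated image (x_in) and the real image (true_X_input)
            # through VGG; x is the VGG output used for the content loss.
            x = ...  # VGG layers omitted in these notes
            return x
        def load_vgg_weight(self, model):
            # Load pretrained VGG weights into the model.
            return model
    class DiscriminatorNetwork:
        def append_gan_network(self, true_X_input):
            x = ...  # discriminator layers omitted in these notes
            return x
    class GenerativeNetwork:
        def create_sr_model(self, ip):
            x = ...  # super-resolution generator layers omitted in these notes
            return x
        def get_generator_output(self, input_img, srgan_model):
            # output_func: a backend function (e.g. keras.backend.function) that
            # maps the input image to the generator output; built elsewhere.
            return self.output_func([input_img])
    class SRGANNetwork:
        def __init__(self, img_width, img_height, batch_size):
            # Attributes inferred from the call in __main__ below.
            self.img_width, self.img_height = img_width, img_height
            self.batch_size = batch_size
            self.srgan_model_ = None           # generator + VGG (+ discriminator)
            self.discriminative_model_ = None  # discriminator alone
        def build_srgan_pretrain_model(self):
            # Generator + VGG network only
            return self.srgan_model_
        def build_discriminator_pretrain_model(self):
            return self.discriminative_model_
        def build_srgan_model(self):
            # Full model: generator + VGG loss + discriminator loss
            return self.srgan_model_
        def pre_train_srgan(self, image_dir, nb_images=50000, nb_epochs=1, use_small_srgan=False):
            datagen = ImageDataGenerator()
            for i in range(nb_epochs):
                # flow_from_directory loops forever, so stop by hand after nb_images;
                # class_mode=None yields image batches without labels.
                flow = datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size)
                for iteration, x in enumerate(flow):
                    if iteration * self.batch_size >= nb_images:
                        break
                    if iteration % 50 == 0 and iteration != 0:
                        pass  # validation: print PSNR
                    # Train only the generator + VGG network
                    if iteration % 1000 == 0 and iteration != 0:
                        pass  # save model weights
        def pre_train_discriminator(self, image_dir, nb_images=50000, nb_epochs=1, batch_size=128):
            datagen = ImageDataGenerator()
            for i in range(nb_epochs):
                flow = datagen.flow_from_directory(image_dir, class_mode=None, batch_size=batch_size)
                for iteration, x in enumerate(flow):
                    if iteration * batch_size >= nb_images:
                        break
                    # Train only the discriminator
                    if iteration % 1000 == 0 and iteration != 0:
                        pass  # save model weights
        def train_full_model(self, image_dir, nb_images=50000, nb_epochs=10):
            datagen = ImageDataGenerator()
            for i in range(nb_epochs):
                flow = datagen.flow_from_directory(image_dir, class_mode=None, batch_size=self.batch_size)
                for iteration, x in enumerate(flow):
                    if iteration * self.batch_size >= nb_images:
                        break
                    if iteration % 50 == 0 and iteration != 0:
                        pass  # validation: print PSNR
                    if iteration % 1000 == 0 and iteration != 0:
                        pass  # save model weights
                    # Train only the discriminator, with SRGAN training disabled
                    # Train only the generator, with discriminator training disabled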
    if __name__ == "__main__":
        # plot_model replaces keras.utils.visualize_util.plot from Keras 1.x
        from keras.utils import plot_model

        # Path to MS COCO dataset
        coco_path = r"D:\Yue\Documents\Dataset\coco2014\train2014"

        '''
        Base Network manager for the SRGAN model

        Width / Height = 32 to reduce the memory requirement for the discriminator.

        Batch size = 1 is slower, but uses the least amount of GPU memory, and also acts as
        Instance Normalization (batch norm with 1 input image) which speeds up training slightly.
        '''

        srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
        srgan_network.build_srgan_model()
        #plot_model(srgan_network.srgan_model_, to_file='SRGAN.png', show_shapes=True)

        # Pretrain the SRGAN network
        #srgan_network.pre_train_srgan(coco_path, nb_images=80000, nb_epochs=1)

        # Pretrain the discriminator network
        #srgan_network.pre_train_discriminator(coco_path, nb_images=40000, nb_epochs=1, batch_size=16)

        # Fully train the SRGAN with VGG loss and Discriminator loss
        srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=5)
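
The training loops above print PSNR during validation but do not show how it is computed. A minimal sketch, assuming the usual definition PSNR = 10 * log10(MAX^2 / MSE) in decibels and images passed as flat lists of pixel values in [0, max_val]; the function name `psnr` is hypothetical and not taken from the original code:

```python
import math


def psnr(y_true, y_pred, max_val=1.0):
    """Peak signal-to-noise ratio in dB (hypothetical helper, not from the post).

    y_true / y_pred: flat sequences of pixel values in [0, max_val].
    """
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("images must be non-empty and of equal size")
    # Mean squared error between ground truth and prediction
    mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)
    if mse == 0:
        return math.inf  # identical images: PSNR is unbounded
    return 10.0 * math.log10(max_val ** 2 / mse)


if __name__ == "__main__":
    hr = [0.0, 0.25, 0.5, 1.0]  # ground-truth high-resolution pixels
    sr = [0.1, 0.35, 0.4, 0.9]  # generator output, each pixel off by 0.1
    print(round(psnr(hr, sr), 2))  # about 20 dB
```

Higher is better: an error of 0.1 per pixel on a [0, 1] scale gives an MSE of 0.01 and therefore roughly 20 dB. If the images are kept in [0, 255], pass `max_val=255`.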
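
The docstring above says that batch size 1 makes batch normalization act as instance normalization. The plain-Python sketch below illustrates why: in training mode, batch norm pools each channel's statistics over every image in the batch, while instance norm uses each image on its own, so with a single image the two coincide. This is a simplified illustration, assuming training-mode statistics and omitting the learned scale and shift; at inference Keras's `BatchNormalization` uses its moving averages instead, so the equivalence only holds during training. All function names here are hypothetical.

```python
import math


def normalize(values, eps=1e-5):
    """Shift to zero mean and scale to unit variance (eps avoids division by 0)."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def batch_norm(batch):
    """Training-mode batch norm without learned scale/shift.

    batch[n][c] is the flat list of H*W pixels of channel c in image n;
    each channel's statistics are pooled over all images in the batch.
    """
    n_images, n_channels = len(batch), len(batch[0])
    size = len(batch[0][0])  # pixels per channel (all images the same size)
    out = [[None] * n_channels for _ in range(n_images)]
    for c in range(n_channels):
        pooled = [v for image in batch for v in image[c]]
        normed = normalize(pooled)
        for n in range(n_images):
            out[n][c] = normed[n * size:(n + 1) * size]
    return out


def instance_norm(batch):
    """Instance norm: statistics computed per image and per channel."""
    return [[normalize(channel) for channel in image] for image in batch]


if __name__ == "__main__":
    one = [[[0.1, 0.4, 0.7, 0.2], [0.9, 0.3, 0.5, 0.6]]]  # 1 image, 2 channels
    print(batch_norm(one) == instance_norm(one))  # True
    two = one + [[[0.8, 0.8, 0.9, 1.0], [0.0, 0.1, 0.2, 0.1]]]
    print(batch_norm(two) == instance_norm(two))  # False
```

With two images, batch norm subtracts a mean shared by both images, so each image's normalized values depend on the other image; instance norm never mixes images.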
    

      

  • Original post: https://www.cnblogs.com/68xi/p/8590600.html