  • Big Data Development: Applying the Keras Code Framework

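         Overall, Keras is a genuinely "easy" deep learning framework. This shows in its reference documentation, which is fairly detailed. Unlike Caffe, where after installation you have to rely on technical blog posts, Keras has its own official documentation (in English only), which gives beginners plenty of room to learn.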

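        Below are my notes on applying the code framework, using an SRGAN (super-resolution GAN) code skeleton as the example.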

         

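    from keras.preprocessing.image import ImageDataGenerator

    # Assumption: the notes do not show how datagen is configured.
    datagen = ImageDataGenerator()

    class VGGNetwork:
        def append_vgg_network(self, x_in, true_X_input):
            # (layers omitted) stack VGG on x_in; x is the output of VGG
            return x
        def load_vgg_weight(self, model):
            # (omitted) load pretrained VGG weights into model
            return model
    class DiscriminatorNetwork:
        def append_gan_network(self, true_X_input):
            # (layers omitted) x is the discriminator output
            return x
    class GenerativeNetwork:
        def create_sr_model(self, ip):
            # (layers omitted) x is the super-resolved output built on input ip
            return x
        def get_generator_output(self, input_img, srgan_model):
            # output_func (set up in the omitted code) maps [input_img] to the generator output
            return self.output_func([input_img])
    class SRGANNetwork:
        def __init__(self, img_width=32, img_height=32, batch_size=1):
            # Arguments match the SRGANNetwork(...) call in __main__ below
            self.img_width = img_width
            self.img_height = img_height
            self.batch_size = batch_size
        def build_srgan_pretrain_model(self):
            # generator + VGG network, for pre-training the generator
            return self.srgan_model_
        def build_discriminator_pretrain_model(self):
            # model used to pre-train the discriminator
            return self.discriminative_model_
        def build_srgan_model(self):
            # full SRGAN: generator trained with VGG loss and discriminator loss
            return self.srgan_model_
        def pre_train_srgan(self, image_dir, nb_images=50000, nb_epochs=1, use_small_srgan=False):
            for i in range(nb_epochs):
                iteration = 0
                # class_mode=None yields image batches only; the iterator never stops on its own
                for x in datagen.flow_from_directory(image_dir, class_mode=None,
                                                     batch_size=self.batch_size):
                    if iteration % 50 == 0 and iteration != 0:
                        pass  # validation: print PSNR
                    # train only the generator + VGG network on x
                    if iteration % 1000 == 0 and iteration != 0:
                        pass  # save model weights
                    iteration += 1
                    if iteration >= nb_images // self.batch_size:
                        break  # assumed epoch length: nb_images images
        def pre_train_discriminator(self, image_dir, nb_images=50000, nb_epochs=1, batch_size=128):
            for i in range(nb_epochs):
                iteration = 0
                for x in datagen.flow_from_directory(image_dir, class_mode=None,
                                                     batch_size=batch_size):
                    # train only the discriminator on x
                    if iteration % 1000 == 0 and iteration != 0:
                        pass  # save model weights
                    iteration += 1
                    if iteration >= nb_images // batch_size:
                        break
        def train_full_model(self, image_dir, nb_images=50000, nb_epochs=10):
            for i in range(nb_epochs):
                iteration = 0
                for x in datagen.flow_from_directory(image_dir, class_mode=None,
                                                     batch_size=self.batch_size):
                    if iteration % 50 == 0 and iteration != 0:
                        pass  # validation: print PSNR
                    if iteration % 1000 == 0 and iteration != 0:
                        pass  # save model weights
                    # train only the discriminator, with SRGAN training disabled
                    # train only the generator, with discriminator training disabled
                    iteration += 1
                    if iteration >= nb_images // self.batch_size:
                        break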
    if __name__ == "__main__":
        from keras.utils import plot_model
    
        # Path to MS COCO dataset
        # (flow_from_directory only picks up images inside subfolders of this path)
        coco_path = r"D:\Yue\Documents\Dataset\coco2014\train2014"
    
        '''
        Base Network manager for the SRGAN model
    
        Width / Height = 32 to reduce the memory requirement for the discriminator.
    
        Batch size = 1 is slower, but uses the least amount of GPU memory, and also acts as
        Instance Normalization (batch norm with 1 input image) which speeds up training slightly.
        '''
    
        srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
        srgan_network.build_srgan_model()
        #plot_model(srgan_network.srgan_model_, to_file='SRGAN.png', show_shapes=True)
    
        # Pretrain the SRGAN network
        #srgan_network.pre_train_srgan(coco_path, nb_images=80000, nb_epochs=1)
    
        # Pretrain the discriminator network
        #srgan_network.pre_train_discriminator(coco_path, nb_images=40000, nb_epochs=1, batch_size=16)
    
        # Fully train the SRGAN with VGG loss and Discriminator loss
        srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=5)
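
The training methods above all share one loop structure. Because the Keras directory iterator yields batches indefinitely, something has to end each epoch; the sketch below assumes an epoch is `nb_images // batch_size` batches. It is a framework-free sketch of the `train_full_model` control flow only: `run_full_training`, `train_discriminator`, `train_generator`, `validate` and `save_weights` are hypothetical names standing in for the real Keras calls, so the schedule (PSNR check every 50 iterations, weight saving every 1000, then one discriminator step followed by one generator step) can be checked without a GPU.

```python
from typing import Callable


def run_full_training(nb_images: int, batch_size: int, nb_epochs: int,
                      train_discriminator: Callable[[int], None],
                      train_generator: Callable[[int], None],
                      validate: Callable[[int], None],
                      save_weights: Callable[[int], None]) -> None:
    """Mirror the step order of train_full_model with injectable (hypothetical) steps.

    Assumption: one epoch = nb_images // batch_size batches, because the
    Keras directory iterator never stops on its own.
    """
    steps_per_epoch = nb_images // batch_size
    for _epoch in range(nb_epochs):
        for iteration in range(steps_per_epoch):
            if iteration % 50 == 0 and iteration != 0:
                validate(iteration)        # the notes print PSNR here
            if iteration % 1000 == 0 and iteration != 0:
                save_weights(iteration)    # checkpoint the models
            # Train D with the generator frozen, then G with D frozen.
            train_discriminator(iteration)
            train_generator(iteration)


if __name__ == "__main__":
    log = []
    run_full_training(
        nb_images=2001, batch_size=1, nb_epochs=1,
        train_discriminator=lambda it: log.append(("D", it)),
        train_generator=lambda it: log.append(("G", it)),
        validate=lambda it: log.append(("val", it)),
        save_weights=lambda it: log.append(("save", it)),
    )
    print(sum(1 for kind, _ in log if kind == "val"))   # 40
    print([it for kind, it in log if kind == "save"])   # [1000, 2000]
```

Injecting the steps as callables keeps the schedule separate from the models, so the same loop could drive the pre-training methods by passing no-op functions for the steps they skip.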
    
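
The loops print PSNR at each validation step, but the notes do not show how it is computed. Below is a minimal sketch of the standard definition, PSNR = 10 · log10(MAX² / MSE), assuming both images are flattened to equal-length sequences of pixel values in [0, max_val]; `psnr` is a hypothetical helper, not part of the original code.

```python
import math
from typing import Sequence


def psnr(pred: Sequence[float], target: Sequence[float], max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio (dB) between two equal-length pixel sequences."""
    if len(pred) != len(target) or not pred:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


if __name__ == "__main__":
    print(round(psnr([0.5, 0.5], [0.0, 0.0]), 2))                # MSE = 0.25 -> 6.02
    print(round(psnr([250, 5], [255, 0], max_val=255), 2))       # MSE = 25 -> 34.15
```

Higher is better; `max_val` must match the pixel range the data generator produces (1.0 for rescaled inputs, 255 for raw 8-bit pixels).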

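
The docstring's claim that batch size 1 makes batch normalization behave like instance normalization can be checked directly. In training mode, batch normalization computes one mean and variance per channel over every image in the batch, while instance normalization computes them per image and per channel; with a single image the two sets of statistics coincide. This sketch uses plain nested lists shaped `[image][channel][pixel]` and omits the learned scale/shift as well as the moving averages batch normalization uses at inference time (where the equivalence no longer holds); `batch_norm` and `instance_norm` are hypothetical helpers, not Keras layers.

```python
import math
from typing import List

Batch = List[List[List[float]]]  # [image][channel][pixel]


def _normalize(values: List[float], mean: float, var: float, eps: float) -> List[float]:
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def batch_norm(batch: Batch, eps: float = 1e-5) -> Batch:
    """Per-channel statistics pooled over every image in the batch."""
    n_channels = len(batch[0])
    out = [[[] for _ in range(n_channels)] for _ in batch]
    for c in range(n_channels):
        pooled = [v for image in batch for v in image[c]]
        mean = sum(pooled) / len(pooled)
        var = sum((v - mean) ** 2 for v in pooled) / len(pooled)
        for i, image in enumerate(batch):
            out[i][c] = _normalize(image[c], mean, var, eps)
    return out


def instance_norm(batch: Batch, eps: float = 1e-5) -> Batch:
    """Per-channel statistics computed separately for each image."""
    out = []
    for image in batch:
        normed = []
        for channel in image:
            mean = sum(channel) / len(channel)
            var = sum((v - mean) ** 2 for v in channel) / len(channel)
            normed.append(_normalize(channel, mean, var, eps))
        out.append(normed)
    return out


if __name__ == "__main__":
    one = [[[1.0, 2.0, 3.0, 4.0], [10.0, 0.0, 5.0, 5.0]]]   # 1 image, 2 channels
    print(batch_norm(one) == instance_norm(one))            # True
    two = one + [[[9.0, 9.0, 8.0, 7.0], [1.0, 1.0, 2.0, 3.0]]]
    print(batch_norm(two) == instance_norm(two))            # False
```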
      

  • Original article: https://www.cnblogs.com/68xi/p/8590600.html