zoukankan      html  css  js  c++  java
  • TensorFlow 2.0——自定义全连接层实现并保存

    import tensorflow as tf
    
    
    def preprocess(x, y):
        """Convert one (image, label) pair into training-ready tensors.

        The image is cast to float32, divided by 255 and shifted by -0.5
        (for uint8 pixels this maps [0, 255] onto [-0.5, 0.5]); the label is
        cast to int32.
        """
        return (tf.cast(x, dtype=tf.float32) / 255 - 0.5,
                tf.cast(y, dtype=tf.int32))
    
    
    batchsz = 128

    # CIFAR-10 training images [50k, 32, 32, 3] with labels [50k, 1],
    # plus the held-out split (x_val, y_val).
    (x, y), (x_val, y_val) = tf.keras.datasets.cifar10.load_data()

    # One-hot encode the labels: [50k, 1] -> [50k, 1, 10].
    y, y_val = tf.one_hot(y, depth=10), tf.one_hot(y_val, depth=10)
    print(x.shape, y.shape)

    # Drop the size-1 axis: [50k, 1, 10] -> [50k, 10].
    y, y_val = tf.squeeze(y), tf.squeeze(y_val)
    print('squeeze后:')
    print(x.shape, y.shape, x.min(), x.max())

    # Both pipelines share `preprocess`; only the training data is shuffled.
    train_db = (tf.data.Dataset.from_tensor_slices((x, y))
                .map(preprocess)
                .shuffle(1000)
                .batch(batchsz))
    test_db = (tf.data.Dataset.from_tensor_slices((x_val, y_val))
               .map(preprocess)
               .batch(batchsz))

    # Sanity-check one training batch; expected shapes (128, 32, 32, 3) and (128, 10).
    sample = next(iter(train_db))
    print('batch:', sample[0].shape, sample[1].shape)
    
    
    #   Custom layer: a bias-free fully connected layer,
    #   used in place of the standard tf.keras.layers.Dense().
    class MyDense(tf.keras.layers.Layer):
        """Linear layer computing ``inputs @ kernel`` (no bias, no activation)."""

        def __init__(self, inp_dim, oup_dim, **kwargs):
            """Create the layer's kernel.

            :param inp_dim: size of the last axis of the inputs
            :param oup_dim: number of output units
            :param kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``)
            """
            super().__init__(**kwargs)
            # `add_weight` replaces the deprecated `add_variable` alias, which
            # newer Keras releases no longer provide. Keyword arguments are used
            # because the positional order of `add_weight` differs between
            # Keras 2 (name, shape, ...) and Keras 3 (shape, ...).
            # The variable keeps its original name 'w'.
            self.kernel = self.add_weight(name='w', shape=[inp_dim, oup_dim])
            # Deliberately no bias term: the layer is a pure matrix product.

        def call(self, inputs, training=None):
            """Return ``inputs @ kernel``; ``training`` is accepted but unused."""
            return inputs @ self.kernel
    
    #   Custom network built from stacked MyDense layers.
    class MyNetwork(tf.keras.Model):
        """Five-layer MLP classifier for 32x32x3 images.

        Layer widths: 3072 -> 256 -> 256 -> 256 -> 32 -> 10, with ReLU after
        every layer except the last. The output is raw logits (no softmax),
        so it should be paired with a loss that uses ``from_logits=True``.
        """

        def __init__(self):
            super().__init__()
            self.fc1 = MyDense(32 * 32 * 3, 256)
            self.fc2 = MyDense(256, 256)
            self.fc3 = MyDense(256, 256)
            self.fc4 = MyDense(256, 32)
            self.fc5 = MyDense(32, 10)

        def call(self, inputs, training=None, mask=None):
            """Run the forward pass.

            :param inputs: image batch, flattened to [b, 32*32*3] (so each
                example must contain 32*32*3 values, e.g. [b, 32, 32, 3])
            :param training: unused
            :param mask: unused
            :return: logits of shape [b, 10]
            """
            #   [b,32,32,3] -> [b,32*32*3]
            x = tf.reshape(inputs, [-1, 32 * 32 * 3])
            #   [b,32*32*3] -> [b,256]
            x = self.fc1(x)
            x = tf.nn.relu(x)
            #   [b,256] -> [b,256]
            x = self.fc2(x)
            x = tf.nn.relu(x)
            #   [b,256] -> [b,256]
            x = self.fc3(x)
            x = tf.nn.relu(x)
            #   [b,256] -> [b,32]
            x = self.fc4(x)
            x = tf.nn.relu(x)
            #   [b,32] -> [b,10]
            x = self.fc5(x)
            #   No activation on the last layer: the loss consumes raw logits.
            return x
    
    def _compile(model):
        """Compile *model* with the optimizer, loss and metrics used by this script."""
        model.compile(
            # `learning_rate` replaces the deprecated `lr` alias, which newer
            # Keras versions reject outright.
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
            # MyNetwork outputs raw logits, hence from_logits=True.
            loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
            metrics=['accuracy'])


    network = MyNetwork()
    _compile(network)
    network.fit(train_db, epochs=13, validation_data=test_db, validation_freq=1)

    network.evaluate(test_db)
    # Save only the parameters (not the architecture).
    # NOTE(review): a path without an extension selects the TensorFlow checkpoint
    # format in Keras 2; Keras 3 requires a '.weights.h5' suffix instead — confirm
    # against the installed Keras version before relying on this path.
    network.save_weights('./save_w_model/test1')

    #   Load the weights-only checkpoint into a fresh model of the same structure.
    network2 = MyNetwork()
    _compile(network2)
    network2.load_weights('./save_w_model/test1')
    print('加载仅有参数的模型')
    network2.evaluate(test_db)
  • 相关阅读:
    WSP部署错误—SharePoint管理框架中的对象“SPSolutionLanguagePack Name=0”依赖其他不存在的对象
    Elevate Permissions To Modify User Profile
    Error with Stsadm Command—Object reference not set to an instance of an object
    ASP.NET MVC3添加Controller时没有Scaffolding options
    测试使用Windows Live Writer写日志
    配置TFS 2010出现错误—SQL Server 登录的安全标识符(SID)与某个指定的域或工作组帐户冲突
    使用ADO.NET DbContext Generator出现错误—Unable to locate file
    CSS
    HTML DIV标签
    数据库
  • 原文地址:https://www.cnblogs.com/cxhzy/p/13697268.html
Copyright © 2011-2022 走看看