今天算是摸鱼了,代码方面学的不多,倒是温习了一遍日语,因为我对这方面感兴趣。。。。下面进入正题,今天只是简单尝试了一下tensorflow中函数式api的使用,其实用它写运行代码跟之前测试代码运行出来的效果是差不多的,因为语句都是一个意思,不过用函数式api写的话个人感觉更好理解一些,而且可以做一些其他操作。
首先看一下如何使用函数式api来实现建立模型并编译,数据集依然是老朋友fashion MNIST:
import tensorflow as tf from tensorflow import keras import matplotlib.pyplot as plt %matplotlib inline fashion_mnist = keras.datasets.fashion_mnist (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() #归一化 train_images = train_images / 255.0 test_images = test_images / 255.0 #建立输入 input = keras.Input(shape=(28, 28)) #函数调用 x = keras.layers.Flatten()(input) x = keras.layers.Dense(32, activation='relu')(x) x = keras.layers.Dropout(0.5)(x) x = keras.layers.Dense(64, activation='relu')(x) #设置输出 output = keras.layers.Dense(10, activation='softmax')(x) #建立模型 model = keras.Model(inputs=input, outputs=output) model.summary()
可以看到在函数式api中需要自己手动创建输入和输出对象,从注释“函数调用”之后就是之前测试代码过程中的模型创建过程,只不过以前写的时候相当于简写,直接在Sequential内就将Dense,Dropout等参数设置好了,在这里我们添加输入对象建立的是Flatten,红箭头标明的None是因为在设置input时shape里没有特别指定,默认为None。
继续:
#编译 model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy']) history = model.fit(train_images,train_labels,epochs=30,validation_data=(test_images,test_labels))
test_loss, test_acc = model.evaluate(test_images, test_labels) plt.plot(history.epoch, history.history['loss'],'r',label='loss') plt.plot(history.epoch, history.history['val_loss'],'b--',label='val_loss') plt.legend()
事实上看过函数式api的代码都感觉很熟悉,前几天测试时没少建立模型,大体上都是这样的构造。如果说函数式api只有这样的话,哪里还用单拎出来呢,所以它还有不同的地方。
我们可以设置两个输入对象,利用函数式api来进行两个图片是否一致的判断:
前面调用什么的都是一样的,包括数据集,因此直接从不一样的地方开始贴,也就是建立input对象的时候:
#建立输入1 input1 = keras.Input(shape=(28, 28)) #建立输入2 input2 = keras.Input(shape=(28, 28)) #函数调用 x1 = keras.layers.Flatten()(input1) x2 = keras.layers.Flatten()(input2) #将这两个对象合并成一个,进行对两个图片是否一样的判断(逻辑回归) x = keras.layers.concatenate([x1, x2]) x = keras.layers.Dense(32, activation='relu')(x) #设置输出,注意激活函数使用的是sigmoid output = keras.layers.Dense(1, activation='sigmoid')(x) #建立模型,注意设置了两个input对象,因此inputs使用数组形式 model = keras.Model(inputs=[input1, input2], outputs=output) model.summary()
可以看到这个新的模型出现了分歧,这也是函数式api的一个应用。
以上就是今天学习的tensorflow相关内容,今天摸鱼了不少时间,就当给自己放了个假= =。