数据增强可以帮助扩展数据集。对图像的增强,就是对图像的简单形变,用来应对因拍照角度不同而引起的图片变形。
数据增强函数
image_gen_train = tf.keras.preprocessing.image.ImageDataGenerator(
# 调整输入特征大小,每个输入特征将乘以该参数
rescale=1.0 / 255, # 归一化
# 图片将在[-45°, 45°]范围内做随机旋转
rotation_range=45,
# 图片将在[-0.15, 0.15]范围内做随机左右偏移,大小保持不变
width_shift_range=0.15,
# 图片将在[-0.15, 0.15]范围内做随机上下偏移,大小保持不变
height_shift_range=0.15,
# 是否做水平翻转操作
horizontal_flip=False,
# 图片将做[0.75, 1.25]范围内做随机缩放,大小保持不变
zoom_range=0.25
)
# x_train 需要是4维数据
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
image_gen_train.fit(x_train)
代码:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 加载输入特征和标签
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 归一化处理,减小计算量,方便神经网络吸收
x_train, x_test = x_train/255.0, x_test/255.0
# 数据集增强
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
image_gen_train = ImageDataGenerator(
rotation_range=45, # 随机旋转45°
width_shift_range=0.15, # 宽度偏移
height_shift_range=0.15, # 高度偏移
horizontal_flip=False, # 不水平翻转
zoom_range=0.5 # 随机缩放
)
image_gen_train.fit(x_train)
# 声明网络结构
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax")
])
# 配置训练方法
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=[tf.keras.metrics.sparse_categorical_accuracy])
# 执行训练过程
model.fit(x_train, y_train,
batch_size=32, epochs=5,
validation_data=(x_test, y_test),
validation_freq=1)
# 打印网络结构和参数
model.summary()