zoukankan      html  css  js  c++  java
  • TensorFlow2.0之数据标准化

    import tensorflow as tf
    import tensorflow.keras as keras
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.preprocessing import StandardScaler
    
    #导入数据
    fashion_mnist = keras.datasets.fashion_mnist
    (X_train_all, y_train_all), (X_test, y_test) = fashion_mnist.load_data()
    X_valid, X_train = X_train_all[:5000], X_train_all[1000:]
    y_valid, y_train = y_train_all[:5000], y_train_all[1000:]
    
    #归一化
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train.astype(np.float32).reshape(-1,28*28)).reshape(-1,28,28)
    X_valid_scaled = scaler.transform(X_valid.astype(np.float32).reshape(-1,28*28)).reshape(-1, 28, 28)
    
    
    #构建模型
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=[28, 28]),
        keras.layers.Dense(300, activation='sigmoid'),
        keras.layers.Dense(100, activation='sigmoid'),
        keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    history = model.fit(X_train_scaled, y_train, epochs=10)
    print(history.history)
    scaler = StandardScaler()定义一个对象用于标准化
    scaler.fit_transform(X_train.astype(np.float32).reshape(-1,28*28)).reshape(-1,28,28)调用fit_transform函数,该函数由两部人份组成,第一部分是fit函数,这个函数产生了每一列的均值和标准差,存入scarler对象中,然后调用transform函数对矩阵进行标准化。而fit_transform函数是对这两个函数的整合。所以可以看到
    X_valid_scaled = scaler.transform(X_valid.astype(np.float32).reshape(-1,28*28)).reshape(-1, 28, 28)直接调用transform而不是fit_transform,因为scaler中已经有相应的均值和方差了,所以以后只要直接transform就行。因为不管是对验证集还是测试集,都是用的训练集的均值和方差。

    可以通过
    print(scaler.mean_)
    print(scaler.scale_)
    查看训练集中每一列(每一个特征)的均值和标准差。
     
     
  • 相关阅读:
    jQuery——通过Ajax发送数据
    Python爬虫入门教程 71-100 续上篇,python爬虫爬取B站视频
    实战演练:PostgreSQL在线扩容
    直播丨Oracle比特币勒索&数据库大咖讲坛
    使用seaborn绘制强化学习中的图片
    nginx stream模块
    工具用的好下班走的早
    10年大数据平台经验,总结出这份数据建设干货(内含多张架构图)
    nginx 配置4层转发
    详解pytorch中的max方法
  • 原文地址:https://www.cnblogs.com/loubin/p/12579269.html
Copyright © 2011-2022 走看看