zoukankan      html  css  js  c++  java
  • Embedding

    Embedding 就是字典映射,把一个类别映射到一个向量上,方便学习特征。比如对于特征gender有取值有 male,female,创建一个矩阵2*2的矩阵:

    [[1,2],
    [3,4]]
    

    把 male 映射到第一行,得到[1,2],female 映射到二行得到[3,4];它与LabelEncoder区别是LabelEncoder 是映射到一个Integer,embedding是映射到一个向量上。

    OneHot 如果某个特征很稀疏,比如有4096个不同的值,那么这一列衍生出的数据有(n_rows, 4096) :
    如果接个Dense,那么Dense一个单元的权重就得是(4096, batch_size),神经网络在训练中需要频繁传递权重,减慢了训练速度,如果把这n_rows行数据映射到数据(?, 4)的数组上,Dense的权重就是. (4, batch_size)。

    用TF 实现Embedding:

    1. 定义数据
    import pandas as pd
    df = pd.DataFrame(data=[0, 1, 0, 1, 1], columns=['gender'])  # 特征数据
    df[:2]
    

    输出:

       gender
    0       0
    1       1
    
    1. 初始化embedding参数
    import tensorflow as tf
    embedding_params = tf.constant([
                [1, 2],
                [3, 4]
            ])
    

    因为要给特征的每一个值都映射到一个不同的向量,所以 embedding_params 的行数就是 gender 不同值的个数,向量长度推荐>=3;
    实际使用可以随机生成embedding参数:

    embedding_params = tf.random.truncated_normal([2, 2])  # 正态分布随机数
    
    1. 映射成数字
    tf.nn.embedding_lookup(embedding_params, df['gender'].values)
    

    TF也有tf.keras.layers.Embedding用来在网络中搭建Embedding层,主要参数含义:

    • input_dim 词袋的大小,也就是不同值的个数,用来生产embedding_params
    • output_dim 输出维度数,也就是映射成向量的长度

    注意:Embedding输出的是 (?, 1, ouput_dim) 而不是 (?, output_dim) ,会加一个维度, 看个例子:

    from tensorflow.keras.layers import Input
    import tensorflow as tf
    import pandas as pd
    
    df = pd.DataFrame(data=[0, 1, 0, 1, 1], columns=['gender'])
    input_gender = Input(shape=(1, ), name="gender")
    layer_embedding = tf.keras.layers.Embedding(input_dim=2, output_dim=4)(input_gender)
    layer_embedding
    

    输出:

    <tf.Tensor 'embedding_2/Identity:0' shape=(None, 1, 4) dtype=float32>
    

    可以使用tf.squeeze 去掉这一维:

    tensor_squeeze = tf.squeeze(layer_embedding, axis=1)
    tensor_squeeze
    

    输出:

    <tf.Tensor 'Squeeze:0' shape=(None, 4) dtype=float32>
    

    对应的还提供了tf.expand_dims来增加维度:

    tf.expand_dims(tensor_squeeze, axis=1)
    

    输出:

    <tf.Tensor 'ExpandDims:0' shape=(None, 1, 4) dtype=float32>
    
  • 相关阅读:
    ubuntu 制做samba
    《Programming WPF》翻译 第4章 前言
    《Programming WPF》翻译 第4章 3.绑定到数据列表
    《Programming WPF》翻译 第4章 4.数据源
    《Programming WPF》翻译 第5章 6.触发器
    《Programming WPF》翻译 第4章 2.数据绑定
    《Programming WPF》翻译 第4章 1.不使用数据绑定
    《Programming WPF》翻译 第5章 7.控件模板
    《Programming WPF》翻译 第5章 8.我们进行到哪里了?
    《Programming WPF》翻译 第5章 5.数据模板和样式
  • 原文地址:https://www.cnblogs.com/oaks/p/14047119.html
Copyright © 2011-2022 走看看