zoukankan      html  css  js  c++  java
  • 【505】NLP实战系列(二)—— keras 中的 Embedding 层

    参考:嵌入层 Embedding

    参考:Python3 assert(断言)


    1. Embedding 层语法

    keras.layers.Embedding(input_dim, output_dim, embeddings_initializer='uniform', embeddings_regularizer=None, activity_regularizer=None, embeddings_constraint=None, mask_zero=False, input_length=None)
    

      将正整数(索引值)转换为固定尺寸的稠密向量。 例如: [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]。该层只能用作模型中的第一层。

    2. 参数说明

    • input_dim: int > 0。词汇表大小, 即,最大整数 index + 1。
    • output_dim: int >= 0。词向量的维度。
    • embeddings_initializer: embeddings 矩阵的初始化方法 (详见 initializers)。
    • embeddings_regularizer: embeddings matrix 的正则化方法 (详见 regularizer)。
    • embeddings_constraint: embeddings matrix 的约束函数 (详见 constraints)。
    • mask_zero: 是否把 0 看作为一个应该被遮蔽的特殊的 "padding" 值。 这对于可变长的 循环神经网络层 十分有用。 如果设定为 True,那么接下来的所有层都必须支持 masking,否则就会抛出异常。 如果 mask_zero 为 True,作为结果,索引 0 就不能被用于词汇表中 (input_dim 应该与 vocabulary + 1 大小相同)。
    • input_length: 输入序列的长度,当它是固定的时。 如果你需要连接 Flatten 和 Dense 层,则这个参数是必须的 (没有它,dense 层的输出尺寸就无法计算)。

      标记红色的是比较重要的参数,一般来说是需要具体赋值的。

    3. 输入尺寸

      尺寸为 (batch_size, sequence_length) 的 2D 张量。

    • batch_size:每个批次的字符串数量
    • sequence_length:字符串长度,多了截断,少了补0

    4. 输出尺寸

      尺寸为 (batch_size, sequence_length, output_dim) 的 3D 张量。

    • batch_size:每个批次的字符串数量
    • sequence_length:字符串长度,多了截断,少了补0
    • output_dim:稠密矩阵维度

    5. 举例

    model = Sequential()
    model.add(Embedding(1000, 64, input_length=10))
    # 模型将输入一个大小为 (batch, input_length) 的整数矩阵。
    # 输入中最大的整数(即词索引)不应该大于 999 (词汇表大小)
    # 64 表示稠密矩阵的维度
    # input_length=10 表示字符串长度
    # 现在 model.output_shape == (None, 10, 64),其中 None 是 batch 的维度。
    
    input_array = np.random.randint(1000, size=(32, 10))
    # 新建一个输入数据
    # 32 表示字符串数量
    # 10 表示字符串长度
    # 整体都是一些小于1000的整数表示,每一个数字对应于一个单词
    
    model.compile('rmsprop', 'mse')
    output_array = model.predict(input_array)
    assert output_array.shape == (32, 10, 64)
    # 没有提示错误,说明维度输出是正确的
    # 32 表示字符串数量
    # 10 表示字符串长度
    # 64 表示稠密绝阵的维度
    

      

  • 相关阅读:
    热更新--动态加载framework
    封装framework注意点
    zip压缩和解压缩
    iOS 网络请求数据缓存
    tomcat服务器访问网址组成
    iOS--支付宝环境集成
    线程10--NSOperation的基本操作
    线程9--NSOperation
    线程8--GCD常见用法
    线程7--GCD的基本使用
  • 原文地址:https://www.cnblogs.com/alex-bn-lee/p/14191176.html
Copyright © 2011-2022 走看看