  • When TensorFlow loads the parameters of a large embedding model, it can fail with the error "cannot be larger than 2GB".

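    The cause is that a variable's value is serialized through protobuf, which caps a single message at 2GB. The fix is to split the embedding matrix into N equal shards, each stored as its own variable, so that every variable stays under 2GB.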
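
    A rough way to choose N is to divide the table's byte size by the 2GB cap and round up. The sketch below assumes the cap applies to the raw bytes of each shard's tensor (taken here as 2**31 - 1 bytes) and ignores protobuf framing overhead, so leave some headroom in practice. `num_shards_needed` is a hypothetical helper, not a TensorFlow API.

```python
# Assumed cap on one serialized tensor: 2**31 - 1 bytes (protobuf's 2GB limit).
# Framing overhead is ignored, so real shards should stay somewhat below this.
PROTOBUF_LIMIT_BYTES = 2 ** 31 - 1


def num_shards_needed(vocab_size, embedding_dim, bytes_per_value=4,
                      limit_bytes=PROTOBUF_LIMIT_BYTES):
    """Smallest shard count so every contiguous row shard fits under limit_bytes.

    Assumes row-wise sharding where the largest shard holds
    ceil(vocab_size / n) rows, as with "div" partitioning.
    """
    row_bytes = embedding_dim * bytes_per_value
    max_rows = limit_bytes // row_bytes  # most rows one shard may hold
    if max_rows == 0:
        raise ValueError("a single embedding row is already larger than the limit")
    # Integer ceiling division avoids float rounding on very large vocabularies.
    return (vocab_size + max_rows - 1) // max_rows


if __name__ == "__main__":
    # 10,000,000 x 128 float32 table (about 5.1 GB).
    print(num_shards_needed(10_000_000, 128))
```

    For that 10,000,000 x 128 float32 table the helper returns 3, i.e. shards of at most 3,333,334 rows (about 1.7 GB each).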

      

    #!/usr/bin/env python
    # coding=utf-8
    # TF 1.x graph-mode API; the compat.v1 import also lets this run under TF 2.x.
    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()

    input_ids = tf.placeholder(dtype=tf.int32, shape=[None, None])

    num_shards = 3
    # Toy 9x3 embedding matrix standing in for a large pretrained table.
    full_weights = np.arange(27).reshape(9, 3)
    vocab_size = full_weights.shape[0]
    # Equal contiguous shards only line up with partition_strategy="div"
    # when the vocabulary divides evenly.
    assert vocab_size % num_shards == 0
    shard_len = vocab_size // num_shards  # integer division: slice bounds must be ints

    weights = []
    for i in range(num_shards):
        begin, end = i * shard_len, (i + 1) * shard_len
        # Each shard gets its own variable and its own constant initializer,
        # so no single tensor has to hold the whole table.
        weights_i = tf.get_variable("words-%02d" % i,
                                    initializer=tf.constant(full_weights[begin:end]))
        weights.append(weights_i)

    # "div" maps ids 0-2 to shard 0, 3-5 to shard 1 and 6-8 to shard 2,
    # matching the contiguous slices above.
    input_embedding = tf.nn.embedding_lookup(weights, input_ids, partition_strategy="div")

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    print(sess.run(weights))

    print(sess.run(input_embedding, feed_dict={input_ids: [[1, 2], [3, 0], [8, 2], [5, 1]]}))

     Output:

        

    [array([[0, 1, 2],
           [3, 4, 5],
           [6, 7, 8]]), array([[ 9, 10, 11],
           [12, 13, 14],
           [15, 16, 17]]), array([[18, 19, 20],
           [21, 22, 23],
           [24, 25, 26]])]
    [[[ 3  4  5]
      [ 6  7  8]]
    
     [[ 9 10 11]
      [ 0  1  2]]
    
     [[24 25 26]
      [ 6  7  8]]
    
     [[15 16 17]
      [ 3  4  5]]]
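
    With partition_strategy="div", tf.nn.embedding_lookup assigns ids to shards contiguously, and when the id count does not divide evenly the first shards each get one extra row (TensorFlow's documentation gives 13 ids over 5 partitions as [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]). That is why the equal slices in the example must be paired with "div". The plain-NumPy sketch below reproduces that mapping; `div_shard_sizes` and `div_lookup_position` are hypothetical helpers, not TensorFlow functions. It also suggests that np.array_split produces exactly the shard sizes "div" expects, so it could replace the manual slicing for an uneven vocabulary (an assumption worth checking against your TF version).

```python
import numpy as np


def div_shard_sizes(num_ids, num_shards):
    """Rows per shard under "div": the first num_ids % num_shards shards get one extra row."""
    base, extra = divmod(num_ids, num_shards)
    return [base + 1 if s < extra else base for s in range(num_shards)]


def div_lookup_position(row_id, num_ids, num_shards):
    """Map a global row id to (shard index, row within that shard) under "div"."""
    base, extra = divmod(num_ids, num_shards)
    boundary = extra * (base + 1)  # ids below this live in the larger shards
    if row_id < boundary:
        return row_id // (base + 1), row_id % (base + 1)
    offset = row_id - boundary
    return extra + offset // base, offset % base


if __name__ == "__main__":
    full = np.arange(13 * 2).reshape(13, 2)  # toy 13-row table
    shards = np.array_split(full, 5)          # sizes 3, 3, 3, 2, 2
    print([len(s) for s in shards], div_shard_sizes(13, 5))
    # Every global id resolves to the same row through the shards.
    for rid in range(13):
        s, r = div_lookup_position(rid, 13, 5)
        assert (shards[s][r] == full[rid]).all()
```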
  • Original article: https://www.cnblogs.com/gongxijun/p/9995960.html