  • Loading an embedding model in TensorFlow for visualization

    1. What it does

    Train a word2vec model with Python's gensim module, then load the model with TensorFlow and visualize its embedding vectors in TensorBoard's Embedding Projector.

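    PS: a w2v model trained with the C++ version cannot be loaded this way, because gensim's Word2Vec.load() only reads models that gensim itself saved. If such a model was written in the standard word2vec text or binary format, KeyedVectors.load_word2vec_format() can usually read its vectors instead (without the training state, so training cannot be resumed).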
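    To show why the two formats are not interchangeable, the sketch below parses the binary layout written by the original C word2vec tool. It assumes the usual structure: a text header "vocab_size dim", then for each word its UTF-8 bytes, one space, dim float32 values in little-endian byte order, and a newline. write_w2v_bin and read_w2v_bin are hypothetical helpers written for illustration, not gensim functions; in practice KeyedVectors.load_word2vec_format(path, binary=True) handles this format.

```python
import os
import struct
import tempfile
from typing import Dict, List


def write_w2v_bin(path: str, vectors: Dict[str, List[float]]) -> None:
    """Write word vectors in the C word2vec binary layout (assumes little-endian floats)."""
    dim = len(next(iter(vectors.values())))
    with open(path, "wb") as f:
        f.write(f"{len(vectors)} {dim}\n".encode("utf-8"))  # header: count and dimension
        for word, vec in vectors.items():
            f.write(word.encode("utf-8") + b" ")    # word, then one space
            f.write(struct.pack(f"<{dim}f", *vec))  # dim float32 values
            f.write(b"\n")                          # record separator


def read_w2v_bin(path: str) -> Dict[str, List[float]]:
    """Read a C word2vec binary file into a dict mapping word -> vector."""
    result: Dict[str, List[float]] = {}
    with open(path, "rb") as f:
        vocab_size, dim = map(int, f.readline().split())  # header line
        for _ in range(vocab_size):
            # collect bytes up to the space that ends the word, skipping the
            # newline left over from the previous record
            raw = bytearray()
            while True:
                ch = f.read(1)
                if ch == b"":
                    raise EOFError("file ended in the middle of a record")
                if ch == b" ":
                    break
                if ch != b"\n":
                    raw += ch
            vec = struct.unpack(f"<{dim}f", f.read(4 * dim))
            result[raw.decode("utf-8")] = list(vec)
    return result


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "toy_vectors.bin")
    write_w2v_bin(path, {"cat": [0.5, -1.0], "狗": [2.0, 0.25]})
    print(read_w2v_bin(path))
```

    A gensim model file saved with model.save() is a pickle-based format instead, which is why neither loader can read the other's files.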

    2. Python code for training the word2vec model

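    import multiprocessing
    
    from gensim.models.word2vec import Word2Vec, LineSentence
    
    # input: one sentence per line, tokens already separated by whitespace
    train_file = "/tmp/train_data"
    
    print('Training started')
    # gensim >= 4.0 parameter names; gensim 3.x used size=128 and iter=10
    model = Word2Vec(LineSentence(train_file), vector_size=128,
                     workers=multiprocessing.cpu_count(), epochs=10)
    print('Training finished')
    # L2-normalize all vectors in place (init_sims(replace=True) before gensim 4.0);
    # the model should not be trained further after this
    model.wv.unit_normalize_all()
    # saved in gensim's own format, not the C word2vec .bin format
    model.save('/tmp/emb.bin')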
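    init_sims(replace=True) (unit_normalize_all() in gensim 4.x) rescales every row of the embedding matrix to unit L2 length, so cosine similarity becomes a plain dot product; the original vector lengths are thrown away, which is why the model should not be trained further afterwards. A NumPy sketch of the same operation on a toy matrix, where l2_normalize_rows is our own illustrative helper; leaving all-zero rows unchanged to avoid dividing by zero is a choice made for this sketch.

```python
import numpy as np


def l2_normalize_rows(vectors: np.ndarray) -> np.ndarray:
    """Return a copy of `vectors` with every row scaled to unit L2 length."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # leave all-zero rows unchanged instead of dividing by zero
    norms[norms == 0] = 1.0
    return vectors / norms


if __name__ == "__main__":
    m = np.array([[3.0, 4.0], [0.0, 2.0]], dtype=np.float32)
    unit = l2_normalize_rows(m)
    print(unit)               # rows scaled to length 1
    print(unit[0] @ unit[1])  # dot product of unit rows = cosine similarity
```

    After normalization the nearest-neighbour search in the projector's cosine mode and a dot-product search give the same ranking.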

    3. Loading the model in TensorFlow and visualizing it

    import os
    
    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in TF 2
    from gensim.models.word2vec import Word2Vec
    from tensorflow.contrib.tensorboard.plugins import projector
    
    log_dir = '/tmp/embedding_log'
    os.makedirs(log_dir, exist_ok=True)
    
    # load the model saved by the training script
    model_file = '/tmp/emb.bin'
    word2vec = Word2Vec.load(model_file)
    
    # vocabulary and matrix share one order: row i of `embedding` is words[i]
    # (gensim >= 4.0; in gensim 3.x use word2vec.wv.index2word)
    words = word2vec.wv.index_to_key
    embedding = np.asarray(word2vec.wv.vectors, dtype=np.float32)
    
    # set up a TensorFlow session
    tf.reset_default_graph()
    sess = tf.InteractiveSession()
    # feed the matrix through a placeholder so it is not stored as a constant
    # inside the graph (GraphDef protos are limited to 2 GB)
    X = tf.Variable(tf.zeros(embedding.shape), name='embedding')
    place = tf.placeholder(tf.float32, shape=embedding.shape)
    set_x = tf.assign(X, place)
    sess.run(tf.global_variables_initializer())
    sess.run(set_x, feed_dict={place: embedding})
    
    # write one label per line, in the same order as the rows of `embedding`
    with open(os.path.join(log_dir, 'metadata.tsv'), 'w', encoding='utf-8') as f:
        for word in words:
            f.write(word + '\n')
    
    # create a TensorFlow summary writer and the projector config
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
    config = projector.ProjectorConfig()
    embedding_conf = config.embeddings.add()
    embedding_conf.tensor_name = 'embedding:0'  # must match the variable name
    embedding_conf.metadata_path = os.path.join(log_dir, 'metadata.tsv')
    projector.visualize_embeddings(summary_writer, config)
    
    # save the variable to a checkpoint; the projector reads the vectors from it
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(log_dir, "model.ckpt"))
    summary_writer.close()
    
    print("Done! View with: tensorboard --logdir " + log_dir)
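    If TensorFlow 1.x (and tf.contrib) is not available, the standalone Embedding Projector at projector.tensorflow.org can load the same data through its "Load" button from two TSV files: one tab-separated vector per line, plus an optional file with one label per line in the same row order. A minimal sketch, assuming the words and vectors have already been taken from the gensim model (for example word2vec.wv.index_to_key and word2vec.wv.vectors); export_tsv is a hypothetical helper and the toy data is made up.

```python
import os
import tempfile
from typing import List, Sequence


def export_tsv(words: List[str], vectors: Sequence[Sequence[float]], out_dir: str) -> None:
    """Write vectors.tsv and metadata.tsv for the standalone Embedding Projector.

    Row i of vectors.tsv and line i of metadata.tsv describe the same word.
    """
    if len(words) != len(vectors):
        raise ValueError("need exactly one label per vector")
    for word in words:
        if "\t" in word or "\n" in word:
            raise ValueError(f"label {word!r} would break the TSV layout")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "vectors.tsv"), "w", encoding="utf-8") as f:
        for vec in vectors:
            f.write("\t".join(str(float(x)) for x in vec) + "\n")
    # a single-column metadata file must not have a header line
    with open(os.path.join(out_dir, "metadata.tsv"), "w", encoding="utf-8") as f:
        for word in words:
            f.write(word + "\n")


if __name__ == "__main__":
    # toy data; with a real model pass word2vec.wv.index_to_key and word2vec.wv.vectors
    out = tempfile.mkdtemp()
    export_tsv(["猫", "dog"], [[0.5, -1.0], [2.0, 0.0]], out)
    print("wrote vectors.tsv and metadata.tsv to", out)
```

    This path skips the checkpoint entirely, at the cost of uploading the vectors through the browser instead of reading them from a local TensorBoard log directory.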
  • Original post: https://www.cnblogs.com/aijianiula/p/10221970.html