  • Loading an embedding model with TensorFlow for visualization

    1. Purpose

    Train a word2vec model with Python's gensim module, then read the model with TensorFlow and visualize the embedding vectors.
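
    The training code uses gensim's LineSentence, which reads a plain-text file where each line is one sentence and the tokens are already separated by whitespace. Text without spaces between words (Chinese, for example) has to be segmented before it reaches this point. A minimal sketch of preparing such a file, assuming a toy English corpus and a naive lower-case/whitespace tokenizer; prepare_corpus and read_like_line_sentence are hypothetical helpers, not gensim functions:

```python
import os
import tempfile

def prepare_corpus(sentences, path):
    """Write one whitespace-tokenized sentence per line (hypothetical helper)."""
    with open(path, "w", encoding="utf-8") as f:
        for sentence in sentences:
            tokens = sentence.lower().split()  # naive tokenizer, an assumption
            if tokens:  # skip sentences that produce no tokens
                f.write(" ".join(tokens) + "\n")

def read_like_line_sentence(path):
    """Yield one token list per line, roughly as gensim's LineSentence does."""
    with open(path, encoding="utf-8") as f:
        for line in f:
            yield line.split()

corpus = ["The cat sat on the mat", "", "Dogs chase cats"]
path = os.path.join(tempfile.mkdtemp(), "train_data")
prepare_corpus(corpus, path)
print(list(read_like_line_sentence(path)))
# [['the', 'cat', 'sat', 'on', 'the', 'mat'], ['dogs', 'chase', 'cats']]
```

    A file written this way can be passed straight to LineSentence(train_file).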

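    Note: a w2v model trained with the C++ version cannot be opened with gensim's Word2Vec.load, which only reads models saved by gensim itself. If the C++ tool writes the standard word2vec text or binary format (as the original C tool does), gensim's KeyedVectors.load_word2vec_format can usually load the vectors instead.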
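
    For reference, the word2vec text format written by the original C tool is simple: the first line holds the vocabulary size and the vector dimension, and each following line is a word followed by its vector components. gensim's KeyedVectors.load_word2vec_format is the real loader (it also handles the binary variant); the sketch below, with a hypothetical read_word2vec_text helper and a toy two-word sample, only illustrates the text layout:

```python
import io

import numpy as np

def read_word2vec_text(stream):
    """Parse word2vec text format into (word list, float32 matrix); hypothetical helper."""
    vocab_size, dim = (int(n) for n in stream.readline().split()[:2])
    words = []
    vectors = np.empty((vocab_size, dim), dtype=np.float32)
    for i in range(vocab_size):
        parts = stream.readline().split()
        words.append(parts[0])  # first field is the word
        vectors[i] = [float(x) for x in parts[1:dim + 1]]  # remaining fields are components
    return words, vectors

sample = io.StringIO("2 3\nking 0.1 0.2 0.3\nqueen 0.4 0.5 0.6\n")
words, vectors = read_word2vec_text(sample)
print(words, vectors.shape)  # ['king', 'queen'] (2, 3)
```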

    2. Python code for training the word2vec model

    import multiprocessing
    
    from gensim.models.word2vec import Word2Vec, LineSentence
    
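    print('Training started')
    train_file = "/tmp/train_data"  # one whitespace-tokenized sentence per line
    
    # gensim 3.x argument names; gensim 4.x renamed size= to vector_size= and iter= to epochs=
    model = Word2Vec(LineSentence(train_file), size=128, workers=multiprocessing.cpu_count(), iter=10)
    print('Training finished')
    # L2-normalize the vectors in place and drop the raw ones (saves memory, but training cannot resume)
    model.init_sims(replace=True)
    model.save('/tmp/emb.bin')  # gensim's own format, reloaded below with Word2Vec.load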
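
    init_sims(replace=True) overwrites every word vector with its L2-normalized (unit-length) version and forgets the originals. The same transformation in plain numpy, as a sketch on a random stand-in matrix rather than the gensim implementation; l2_normalize_rows is a hypothetical helper:

```python
import numpy as np

def l2_normalize_rows(matrix):
    """Return a copy of matrix with every row scaled to unit Euclidean length."""
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # leave all-zero rows as zeros instead of dividing by 0
    return matrix / norms

rng = np.random.default_rng(0)
raw = rng.normal(size=(5, 128))  # stand-in for 5 word vectors of size 128
unit = l2_normalize_rows(raw)
print(np.allclose(np.linalg.norm(unit, axis=1), 1.0))  # True

# For unit vectors a plain dot product equals cosine similarity
cosine = raw[0] @ raw[1] / (np.linalg.norm(raw[0]) * np.linalg.norm(raw[1]))
print(np.isclose(unit[0] @ unit[1], cosine))  # True
```

    This is why the normalized model is convenient for similarity queries and for the projector's cosine-based views: vector length no longer influences the comparison.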

    3. Reading the model with TensorFlow and visualizing it

    # Targets TensorFlow 1.x (tf.contrib was removed in TF 2.x) and gensim 3.x
    import numpy as np
    import tensorflow as tf
    import os
    from gensim.models.word2vec import Word2Vec
    from tensorflow.contrib.tensorboard.plugins import projector
    
    log_dir = '/tmp/embedding_log'
    os.makedirs(log_dir, exist_ok=True)
    
    
    # load the gensim model
    model_file = '/tmp/emb.bin'
    word2vec = Word2Vec.load(model_file)
    
    # fix one word order and use it for both the vectors and the labels
    words = list(word2vec.wv.vocab.keys())
    embedding = np.empty((len(words), word2vec.vector_size), dtype=np.float32)
    for i, word in enumerate(words):
        embedding[i] = word2vec.wv[word]
    
    # set up a TensorFlow session
    tf.reset_default_graph()
    sess = tf.InteractiveSession()
    # Feed the matrix through a placeholder instead of baking it into the graph as a
    # constant; validate_shape=False lets the 1-element variable take on the full shape
    X = tf.Variable([0.0], name='embedding')
    place = tf.placeholder(tf.float32, shape=embedding.shape)
    set_x = tf.assign(X, place, validate_shape=False)
    sess.run(tf.global_variables_initializer())
    sess.run(set_x, feed_dict={place: embedding})
    
    # write labels: line i of metadata.tsv labels row i of the embedding
    with open(os.path.join(log_dir, 'metadata.tsv'), 'w', encoding='utf-8') as f:
        for word in words:
            f.write(word + '\n')
    
    # create a TensorFlow summary writer
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
    config = projector.ProjectorConfig()
    embedding_conf = config.embeddings.add()
    embedding_conf.tensor_name = 'embedding:0'
    embedding_conf.metadata_path = os.path.join(log_dir, 'metadata.tsv')
    projector.visualize_embeddings(summary_writer, config)
    
    # save the model
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(log_dir, "model.ckpt"))
    
    print("Done!")
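
    The projector pairs labels with vectors purely by position: line i of metadata.tsv labels row i of the embedding tensor, and a single-column metadata file has no header row. As an alternative route that needs no checkpoint, the standalone Embedding Projector (projector.tensorflow.org) can load a tab-separated vectors file plus a metadata file. A sketch of writing both, assuming a word list and matrix built as in the script above; export_projector_tsv and the file names are hypothetical, and the demo uses toy data:

```python
import os
import tempfile

import numpy as np

def export_projector_tsv(words, embedding, out_dir):
    """Write vectors.tsv and metadata.tsv with matching row order."""
    if len(words) != embedding.shape[0]:
        raise ValueError("one label per embedding row is required")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "vectors.tsv"), "w", encoding="utf-8") as vf, \
         open(os.path.join(out_dir, "metadata.tsv"), "w", encoding="utf-8") as mf:
        for word, vector in zip(words, embedding):
            vf.write("\t".join(f"{x:.6f}" for x in vector) + "\n")  # one row per word
            mf.write(word + "\n")  # single column, so no header row

words = ["king", "queen", "apple"]
embedding = np.array([[0.1, 0.2], [0.15, 0.25], [-0.3, 0.9]], dtype=np.float32)
out_dir = tempfile.mkdtemp()
export_projector_tsv(words, embedding, out_dir)
print(sorted(os.listdir(out_dir)))  # ['metadata.tsv', 'vectors.tsv']
```

    To view the checkpoint written by the TensorFlow script itself, point TensorBoard at the log directory (tensorboard --logdir /tmp/embedding_log) and open the Projector tab.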
  • Original post: https://www.cnblogs.com/aijianiula/p/10221970.html