注:代码是网上下载的,但是找不到原始出处了,侵权则删
先写出visual类:
tf1.x:
class TF_visualizer(object):
    """Export pre-computed embedding vectors to the TensorBoard Embedding
    Projector using the TensorFlow 1.x API.

    Parameters
    ----------
    dimension : int
        Dimensionality of each embedding vector.
    vecs_file : str
        Path to a text file with one comma-separated vector per line.
    metadata_file : str
        Path to the metadata (label) file; row i describes vector i.
    output_path : str
        Directory where the checkpoint and projector config are written.
    """

    def __init__(self, dimension, vecs_file, metadata_file, output_path):
        self.dimension = dimension
        self.vecs_file = vecs_file
        self.metadata_file = metadata_file
        self.output_path = output_path
        self.vecs = []
        with open(self.vecs_file, 'r') as vecs:
            for line in vecs:
                # Iterating a file keeps the trailing newline, so the
                # original `line != ''` test was always true and blank
                # lines slipped through; strip before testing.
                if line.strip():
                    self.vecs.append(line)

    def visualize(self):
        """Save the vectors as a checkpoint plus projector config so
        `tensorboard --logdir=<output_path>` can render them."""
        # adding into projector
        config = projector.ProjectorConfig()

        # Parse the comma-separated text vectors into a dense matrix.
        placeholder = np.zeros((len(self.vecs), self.dimension))
        for i, line in enumerate(self.vecs):
            # np.fromstring(text, sep=',') is deprecated; split explicitly.
            placeholder[i] = np.array(line.strip().split(','), dtype=np.float64)

        embedding_var = tf.Variable(placeholder, trainable=False, name='amazon')

        embed = config.embeddings.add()
        embed.tensor_name = embedding_var.name
        embed.metadata_path = self.metadata_file

        # define the model without training: just initialize and save
        # the variable so the projector can load it from the checkpoint.
        sess = tf.InteractiveSession()
        tf.global_variables_initializer().run()
        saver = tf.train.Saver()
        saver.save(sess, os.path.join(self.output_path, 'w2x_metadata.ckpt'))

        writer = tf.summary.FileWriter(self.output_path, sess.graph)
        projector.visualize_embeddings(writer, config)
        sess.close()

        print('Run `tensorboard --logdir={0}` to run visualize result on tensorboard'.format(self.output_path))
tf2.x:
class TF_visualizer(object):
    """Export pre-computed embedding vectors to the TensorBoard Embedding
    Projector using the TensorFlow 2.x API.

    Parameters
    ----------
    dimension : int
        Dimensionality of each embedding vector.
    vecs_file : str
        Path to a text file with one comma-separated vector per line.
    metadata_file : str
        Path to the metadata (label) file; row i describes vector i.
    output_path : str
        Directory where the checkpoint and projector config are written.
    """

    def __init__(self, dimension, vecs_file, metadata_file, output_path):
        self.dimension = dimension
        self.vecs_file = vecs_file
        self.metadata_file = metadata_file
        self.output_path = output_path
        self.vecs = []
        with open(self.vecs_file, 'r') as vecs:
            for line in vecs:
                # Iterating a file keeps the trailing newline, so the
                # original `line != ''` test was always true and blank
                # lines slipped through; strip before testing.
                if line.strip():
                    self.vecs.append(line)

    def visualize(self):
        """Save the vectors with tf.train.Checkpoint plus a projector
        config so `tensorboard --logdir=<output_path>` can render them."""
        # adding into projector
        config = projector.ProjectorConfig()

        # Parse the comma-separated text vectors into a dense matrix.
        placeholder = np.zeros((len(self.vecs), self.dimension))
        for i, line in enumerate(self.vecs):
            # np.fromstring(text, sep=',') is deprecated; split explicitly.
            placeholder[i] = np.array(line.strip().split(','), dtype=np.float64)

        embedding_var = tf.Variable(placeholder, trainable=False, name='kmeans')

        embed = config.embeddings.add()
        # Path under which tf.train.Checkpoint(embedding=...) stores the
        # variable inside the checkpoint; must match the keyword below.
        embed.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
        embed.metadata_path = self.metadata_file

        checkpoint = tf.train.Checkpoint(embedding=embedding_var)
        checkpoint.save(os.path.join(self.output_path, 'w2x_metadata.ckpt'))
        projector.visualize_embeddings(self.output_path, config)

        print('Run `tensorboard --logdir={0}` to run visualize result on tensorboard'.format(self.output_path))
然后调用类:
# Directory holding the vector/metadata files; also receives the output.
output = '/home/xx'

# create a new tensor board visualizer for the 768-dimensional vectors
visualizer = TF_visualizer(
    dimension=768,
    vecs_file=os.path.join(output, 'amazon_vec.tsv'),
    metadata_file=os.path.join(output, 'amazon.tsv'),
    output_path=output,
)

# Write the checkpoint + projector config for TensorBoard.
visualizer.visualize()
其中,amazon_vec.tsv 中存放向量(词向量、句子向量等,每行一个逗号分隔的向量);amazon.tsv 中存放原始数据,格式为 id,label,title,其中 id 和 title 可以随意定义,label 则是对应向量的标识。两个文件逐行一一对应(即 amazon_vec.tsv 的第 n 行向量对应 amazon.tsv 的第 n 行数据)。
注:amazon.tsv举例:
0,快乐,0
1,幸福,1
2,哪里,2
3,剪子,3
4,鼠标,4
amazon_vec.tsv举例:
0.21729235351085663,-0.21714180707931519,0.10137219727039337,0.530093789100647,-0.16228507459163666
1.43824303150177,0.40661126375198364,-1.3043369054794312,0.4775696396827698,-0.21097205579280853
-0.3607819080352783,0.9494154453277588,1.102367877960205,1.2270256280899048,0.4637971818447113
0.7370161414146423,-1.6456717252731323,-2.1842262744903564,2.0185391902923584,1.6656044721603394
1.03994882106781,-1.7641232013702393,1.042765736579895,2.6722264289855957,1.6226638555526733
最后,命令行输入
tensorboard --logdir=/home/xx
在浏览器输入http://xx-desktop:6006即可看到可视化的数据(6006是默认端口)