1. Core code: Word2VEC.java
package com.ansj.vec;

import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeSet;

import com.ansj.vec.domain.WordEntry;
import com.ansj.vec.util.WordKmeans;
import com.ansj.vec.util.WordKmeans.Classes;

public class Word2VEC {

    public static void main(String[] args) throws IOException {

        // Learn learn = new Learn();
        // learn.learnFile(new File("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\Result_Country.txt"));
        // learn.saveModel(new File("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\javaSkip1"));

        Word2VEC vec = new Word2VEC();
        vec.loadJavaModel("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\javaSkip1");
        System.out.println("中国 " + Arrays.toString(vec.getWordVector("中国")));
        System.out.println("何润东 " + Arrays.toString(vec.getWordVector("何润东")));
        System.out.println("足球 " + Arrays.toString(vec.getWordVector("足球")));

        String str = "中国";
        System.out.println(vec.distance(str));

        WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 10);
        Classes[] explain = wordKmeans.explain();
        for (int i = 0; i < explain.length; i++) {
            System.out.println("--------" + i + "---------");
            System.out.println(explain[i].getTop(10));
        }
    }

    private HashMap<String, float[]> wordMap = new HashMap<String, float[]>();

    private int words;
    private int size;
    private int topNSize = 40;

    /**
     * Load a model in the binary format produced by Google's original C word2vec.
     *
     * @param path path of the model file
     * @throws IOException
     */
    public void loadGoogleModel(String path) throws IOException {
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
            // vocabulary size
            words = Integer.parseInt(readString(dis));
            // vector dimension
            size = Integer.parseInt(readString(dis));
            String word;
            float[] vectors;
            for (int i = 0; i < words; i++) {
                word = readString(dis);
                vectors = new float[size];
                double len = 0;
                for (int j = 0; j < size; j++) {
                    float vector = readFloat(dis);
                    len += vector * vector;
                    vectors[j] = vector;
                }
                len = Math.sqrt(len);
                // normalize to unit length so a dot product equals cosine similarity
                for (int j = 0; j < size; j++) {
                    vectors[j] /= len;
                }
                wordMap.put(word, vectors);
                dis.read(); // skip the byte separating two entries
            }
        }
    }

    /**
     * Load a model saved by Learn#saveModel (Java DataOutputStream format).
     *
     * @param path path of the model file
     * @throws IOException
     */
    public void loadJavaModel(String path) throws IOException {
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
            words = dis.readInt();
            size = dis.readInt();

            float vector;
            String key;
            float[] value;
            for (int i = 0; i < words; i++) {
                double len = 0;
                key = dis.readUTF();
                value = new float[size];
                for (int j = 0; j < size; j++) {
                    vector = dis.readFloat();
                    len += vector * vector;
                    value[j] = vector;
                }

                len = Math.sqrt(len);
                // normalize to unit length
                for (int j = 0; j < size; j++) {
                    value[j] /= len;
                }
                wordMap.put(key, value);
            }
        }
    }

    private static final int MAX_SIZE = 50;

    /**
     * Word analogy: find the words closest to vector(word1) - vector(word0) + vector(word2),
     * e.g. king - man + woman ≈ queen.
     *
     * @return the topN closest words, or null if any query word is unknown
     */
    public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
        float[] wv0 = getWordVector(word0);
        float[] wv1 = getWordVector(word1);
        float[] wv2 = getWordVector(word2);

        if (wv1 == null || wv2 == null || wv0 == null) {
            return null;
        }
        float[] wordVector = new float[size];
        for (int i = 0; i < size; i++) {
            wordVector[i] = wv1[i] - wv0[i] + wv2[i];
        }
        float[] tempVector;
        String name;
        List<WordEntry> wordEntrys = new ArrayList<WordEntry>(topNSize);
        for (Entry<String, float[]> entry : wordMap.entrySet()) {
            name = entry.getKey();
            if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
                continue;
            }
            float dist = 0;
            tempVector = entry.getValue();
            for (int i = 0; i < wordVector.length; i++) {
                dist += wordVector[i] * tempVector[i];
            }
            insertTopN(name, dist, wordEntrys);
        }
        return new TreeSet<WordEntry>(wordEntrys);
    }

    private void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
        if (wordsEntrys.size() < topNSize) {
            wordsEntrys.add(new WordEntry(name, score));
            return;
        }
        // replace the current minimum if the new score is larger
        float min = Float.MAX_VALUE;
        int minOffe = 0;
        for (int i = 0; i < topNSize; i++) {
            WordEntry wordEntry = wordsEntrys.get(i);
            if (min > wordEntry.score) {
                min = wordEntry.score;
                minOffe = i;
            }
        }

        if (score > min) {
            wordsEntrys.set(minOffe, new WordEntry(name, score));
        }
    }

    public Set<WordEntry> distance(String queryWord) {

        float[] center = wordMap.get(queryWord);
        if (center == null) {
            return Collections.emptySet();
        }

        int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
        TreeSet<WordEntry> result = new TreeSet<WordEntry>();

        // Float.MIN_VALUE is the smallest positive float, so words with a
        // non-positive similarity are never considered
        double min = Float.MIN_VALUE;
        for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
            float[] vector = entry.getValue();
            float dist = 0;
            for (int i = 0; i < vector.length; i++) {
                dist += center[i] * vector[i];
            }

            if (dist > min) {
                result.add(new WordEntry(entry.getKey(), dist));
                if (resultSize < result.size()) {
                    result.pollLast();
                }
                min = result.last().score;
            }
        }
        // drop the query word itself, which is always the closest match
        result.pollFirst();

        return result;
    }

    public Set<WordEntry> distance(List<String> words) {

        float[] center = null;
        for (String word : words) {
            center = sum(center, wordMap.get(word));
        }

        if (center == null) {
            return Collections.emptySet();
        }

        int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
        TreeSet<WordEntry> result = new TreeSet<WordEntry>();

        double min = Float.MIN_VALUE;
        for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
            float[] vector = entry.getValue();
            float dist = 0;
            for (int i = 0; i < vector.length; i++) {
                dist += center[i] * vector[i];
            }

            if (dist > min) {
                result.add(new WordEntry(entry.getKey(), dist));
                if (resultSize < result.size()) {
                    result.pollLast();
                }
                min = result.last().score;
            }
        }
        result.pollFirst();

        return result;
    }

    private float[] sum(float[] center, float[] fs) {
        if (center == null && fs == null) {
            return null;
        }
        if (fs == null) {
            return center;
        }
        if (center == null) {
            return fs;
        }
        for (int i = 0; i < fs.length; i++) {
            center[i] += fs[i];
        }
        return center;
    }

    /**
     * Get the vector of a word.
     *
     * @param word
     * @return the (normalized) vector, or null if the word is unknown
     */
    public float[] getWordVector(String word) {
        return wordMap.get(word);
    }

    public static float readFloat(InputStream is) throws IOException {
        byte[] bytes = new byte[4];
        if (is.read(bytes) < 4) {
            throw new EOFException("unexpected end of stream while reading a float");
        }
        return getFloat(bytes);
    }

    /**
     * Decode a little-endian float from 4 bytes.
     *
     * @param b
     * @return
     */
    public static float getFloat(byte[] b) {
        int accum = 0;
        accum = accum | (b[0] & 0xff) << 0;
        accum = accum | (b[1] & 0xff) << 8;
        accum = accum | (b[2] & 0xff) << 16;
        accum = accum | (b[3] & 0xff) << 24;
        return Float.intBitsToFloat(accum);
    }

    /**
     * Read a whitespace-terminated string (the C word2vec format separates
     * tokens with ' ' or '\n').
     *
     * @param dis
     * @return
     * @throws IOException
     */
    private static String readString(DataInputStream dis) throws IOException {
        byte[] bytes = new byte[MAX_SIZE];
        byte b = dis.readByte();
        int i = -1;
        StringBuilder sb = new StringBuilder();
        while (b != 32 && b != 10) {
            i++;
            bytes[i] = b;
            b = dis.readByte();
            if (i == 49) { // buffer full: flush and keep reading
                sb.append(new String(bytes));
                i = -1;
                bytes = new byte[MAX_SIZE];
            }
        }
        sb.append(new String(bytes, 0, i + 1));
        return sb.toString();
    }

    public int getTopNSize() {
        return topNSize;
    }

    public void setTopNSize(int topNSize) {
        this.topNSize = topNSize;
    }

    public HashMap<String, float[]> getWordMap() {
        return wordMap;
    }

    public int getWords() {
        return words;
    }

    public int getSize() {
        return size;
    }

}
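The analogy method is not exercised by the main method above. Here is a minimal sketch of how it could be called; the model path and the three query words are assumptions for illustration. analogy(word0, word1, word2) ranks vocabulary words by their dot product with vector(word1) - vector(word0) + vector(word2).

import java.util.TreeSet;
import com.ansj.vec.Word2VEC;
import com.ansj.vec.domain.WordEntry;

public class AnalogyDemo {
    public static void main(String[] args) throws Exception {
        Word2VEC vec = new Word2VEC();
        // hypothetical path: any model written by Learn#saveModel
        vec.loadJavaModel("library/javaVector");
        // words closest to vector("国王") - vector("男人") + vector("女人")
        TreeSet<WordEntry> result = vec.analogy("男人", "国王", "女人");
        if (result == null) {
            System.out.println("one of the query words is not in the vocabulary");
            return;
        }
        for (WordEntry entry : result) {
            System.out.println(entry); // WordEntry#toString prints "name score"
        }
    }
}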
2. Word-vector model training code: Learn.java
package com.ansj.vec;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.ansj.vec.util.MapCount;
import com.ansj.vec.domain.HiddenNeuron;
import com.ansj.vec.domain.Neuron;
import com.ansj.vec.domain.WordNeuron;
import com.ansj.vec.util.Haffman;

public class Learn {

    private Map<String, Neuron> wordMap = new HashMap<>();

    /**
     * dimensionality of the feature vectors (layer size)
     */
    private int layerSize = 200;

    /**
     * context window size
     */
    private int window = 5;

    private double sample = 1e-3;
    private double alpha = 0.025;
    private double startingAlpha = alpha;

    public int EXP_TABLE_SIZE = 1000;

    private Boolean isCbow = false;

    private double[] expTable = new double[EXP_TABLE_SIZE];

    private int trainWordsCount = 0;

    private int MAX_EXP = 6;

    public Learn(Boolean isCbow, Integer layerSize, Integer window, Double alpha,
                 Double sample) {
        createExpTable();
        if (isCbow != null) {
            this.isCbow = isCbow;
        }
        if (layerSize != null)
            this.layerSize = layerSize;
        if (window != null)
            this.window = window;
        if (alpha != null)
            this.alpha = alpha;
        if (sample != null)
            this.sample = sample;
    }

    public Learn() {
        createExpTable();
    }

    /**
     * trainModel
     *
     * @throws IOException
     */
    private void trainModel(File file) throws IOException {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(
            new FileInputStream(file)))) {
            String temp = null;
            long nextRandom = 5;
            int wordCount = 0;
            int lastWordCount = 0;
            int wordCountActual = 0;
            while ((temp = br.readLine()) != null) {
                if (wordCount - lastWordCount > 10000) {
                    System.out.println("alpha:" + alpha + " Progress: "
                        + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100)
                        + "%");
                    wordCountActual += wordCount - lastWordCount;
                    lastWordCount = wordCount;
                    // linearly decay the learning rate, with a floor at 0.01% of the start
                    alpha = startingAlpha
                        * (1 - wordCountActual / (double) (trainWordsCount + 1));
                    if (alpha < startingAlpha * 0.0001) {
                        alpha = startingAlpha * 0.0001;
                    }
                }
                String[] strs = temp.split(" ");
                wordCount += strs.length;
                List<WordNeuron> sentence = new ArrayList<WordNeuron>();
                for (int i = 0; i < strs.length; i++) {
                    Neuron entry = wordMap.get(strs[i]);
                    if (entry == null) {
                        continue;
                    }
                    // The subsampling randomly discards frequent words while keeping the
                    // ranking same
                    if (sample > 0) {
                        double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1)
                            * (sample * trainWordsCount) / entry.freq;
                        nextRandom = nextRandom * 25214903917L + 11;
                        if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
                            continue;
                        }
                    }
                    sentence.add((WordNeuron) entry);
                }

                for (int index = 0; index < sentence.size(); index++) {
                    nextRandom = nextRandom * 25214903917L + 11;
                    // keep the random window shrink in [0, window); a plain
                    // (int) cast of nextRandom % window can be negative
                    int b = (int) Math.abs(nextRandom % window);
                    if (isCbow) {
                        cbowGram(index, sentence, b);
                    } else {
                        skipGram(index, sentence, b);
                    }
                }

            }
            System.out.println("Vocab size: " + wordMap.size());
            System.out.println("Words in train file: " + trainWordsCount);
            System.out.println("training finished successfully!");
        }
    }

    /**
     * Train the skip-gram model on one position of a sentence.
     *
     * @param index    position of the current word
     * @param sentence the (subsampled) sentence
     * @param b        random shrink of the window
     */
    private void skipGram(int index, List<WordNeuron> sentence, int b) {
        WordNeuron word = sentence.get(index);
        int a, c = 0;
        for (a = b; a < window * 2 + 1 - b; a++) {
            if (a == window) {
                continue;
            }
            c = index - window + a;
            if (c < 0 || c >= sentence.size()) {
                continue;
            }

            double[] neu1e = new double[layerSize]; // accumulated error
            // HIERARCHICAL SOFTMAX
            List<Neuron> neurons = word.neurons;
            WordNeuron we = sentence.get(c);
            for (int i = 0; i < neurons.size(); i++) {
                HiddenNeuron out = (HiddenNeuron) neurons.get(i);
                double f = 0;
                // Propagate hidden -> output
                for (int j = 0; j < layerSize; j++) {
                    f += we.syn0[j] * out.syn1[j];
                }
                if (f <= -MAX_EXP || f >= MAX_EXP) {
                    continue;
                } else {
                    // look up sigmoid(f) in the precomputed table
                    f = (f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2);
                    f = expTable[(int) f];
                }
                // 'g' is the gradient multiplied by the learning rate
                double g = (1 - word.codeArr[i] - f) * alpha;
                // Propagate errors output -> hidden
                for (c = 0; c < layerSize; c++) {
                    neu1e[c] += g * out.syn1[c];
                }
                // Learn weights hidden -> output
                for (c = 0; c < layerSize; c++) {
                    out.syn1[c] += g * we.syn0[c];
                }
            }

            // Learn weights input -> hidden
            for (int j = 0; j < layerSize; j++) {
                we.syn0[j] += neu1e[j];
            }
        }

    }

    /**
     * Train the CBOW (continuous bag-of-words) model on one position.
     *
     * @param index
     * @param sentence
     * @param b
     */
    private void cbowGram(int index, List<WordNeuron> sentence, int b) {
        WordNeuron word = sentence.get(index);
        int a, c = 0;

        List<Neuron> neurons = word.neurons;
        double[] neu1e = new double[layerSize]; // accumulated error
        double[] neu1 = new double[layerSize];  // sum of the context vectors
        WordNeuron last_word;

        for (a = b; a < window * 2 + 1 - b; a++)
            if (a != window) {
                c = index - window + a;
                if (c < 0)
                    continue;
                if (c >= sentence.size())
                    continue;
                last_word = sentence.get(c);
                if (last_word == null)
                    continue;
                for (c = 0; c < layerSize; c++)
                    neu1[c] += last_word.syn0[c];
            }

        // HIERARCHICAL SOFTMAX
        for (int d = 0; d < neurons.size(); d++) {
            HiddenNeuron out = (HiddenNeuron) neurons.get(d);
            double f = 0;
            // Propagate hidden -> output
            for (c = 0; c < layerSize; c++)
                f += neu1[c] * out.syn1[c];
            if (f <= -MAX_EXP)
                continue;
            else if (f >= MAX_EXP)
                continue;
            else
                f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
            // 'g' is the gradient multiplied by the learning rate
            // double g = (1 - word.codeArr[d] - f) * alpha;
            // double g = f*(1-f)*( word.codeArr[i] - f) * alpha;
            double g = f * (1 - f) * (word.codeArr[d] - f) * alpha;
            for (c = 0; c < layerSize; c++) {
                neu1e[c] += g * out.syn1[c];
            }
            // Learn weights hidden -> output
            for (c = 0; c < layerSize; c++) {
                out.syn1[c] += g * neu1[c];
            }
        }
        for (a = b; a < window * 2 + 1 - b; a++) {
            if (a != window) {
                c = index - window + a;
                if (c < 0)
                    continue;
                if (c >= sentence.size())
                    continue;
                last_word = sentence.get(c);
                if (last_word == null)
                    continue;
                for (c = 0; c < layerSize; c++)
                    last_word.syn0[c] += neu1e[c];
            }
        }
    }

    /**
     * Count word frequencies and build the vocabulary.
     *
     * @param file
     * @throws IOException
     */
    private void readVocab(File file) throws IOException {
        MapCount<String> mc = new MapCount<>();
        try (BufferedReader br = new BufferedReader(new InputStreamReader(
            new FileInputStream(file)))) {
            String temp = null;
            while ((temp = br.readLine()) != null) {
                String[] split = temp.split(" ");
                trainWordsCount += split.length;
                for (String string : split) {
                    mc.add(string);
                }
            }
        }
        for (Entry<String, Integer> element : mc.get().entrySet()) {
            // note: the frequency is normalized by vocabulary size here,
            // not by the total token count
            wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
                (double) element.getValue() / mc.size(), layerSize));
        }
    }

    /**
     * Build the vocabulary from pre-classified files, assigning each word the
     * category in which it is most frequent.
     *
     * @param files
     * @throws IOException
     * @throws FileNotFoundException
     */
    private void readVocabWithSupervised(File[] files) throws IOException {
        for (int category = 0; category < files.length; category++) {
            // learn from several files, one per category
            MapCount<String> mc = new MapCount<>();
            try (BufferedReader br = new BufferedReader(new InputStreamReader(
                new FileInputStream(files[category])))) {
                String temp = null;
                while ((temp = br.readLine()) != null) {
                    String[] split = temp.split(" ");
                    trainWordsCount += split.length;
                    for (String string : split) {
                        mc.add(string);
                    }
                }
            }
            for (Entry<String, Integer> element : mc.get().entrySet()) {
                double tarFreq = (double) element.getValue() / mc.size();
                if (wordMap.get(element.getKey()) != null) {
                    double srcFreq = wordMap.get(element.getKey()).freq;
                    if (srcFreq >= tarFreq) {
                        continue;
                    } else {
                        Neuron wordNeuron = wordMap.get(element.getKey());
                        wordNeuron.category = category;
                        wordNeuron.freq = tarFreq;
                    }
                } else {
                    wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
                        tarFreq, category, layerSize));
                }
            }
        }
    }

    /**
     * Precompute the sigmoid table: f(x) = exp(x) / (exp(x) + 1)
     * for x in [-MAX_EXP, MAX_EXP].
     */
    private void createExpTable() {
        for (int i = 0; i < EXP_TABLE_SIZE; i++) {
            expTable[i] = Math.exp((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
            expTable[i] = expTable[i] / (expTable[i] + 1);
        }
    }

    /**
     * Learn from a file.
     *
     * @param file
     * @throws IOException
     */
    public void learnFile(File file) throws IOException {
        readVocab(file);
        new Haffman(layerSize).make(wordMap.values());

        // build the Huffman path for every word neuron
        for (Neuron neuron : wordMap.values()) {
            ((WordNeuron) neuron).makeNeurons();
        }

        trainModel(file);
    }

    /**
     * Learn from pre-classified files.
     *
     * @param summaryFile     merged corpus file
     * @param classifiedFiles one file per category
     * @throws IOException
     */
    public void learnFile(File summaryFile, File[] classifiedFiles)
        throws IOException {
        readVocabWithSupervised(classifiedFiles);
        new Haffman(layerSize).make(wordMap.values());
        // build the Huffman path for every word neuron
        for (Neuron neuron : wordMap.values()) {
            ((WordNeuron) neuron).makeNeurons();
        }
        trainModel(summaryFile);
    }

    /**
     * Save the model.
     */
    public void saveModel(File file) {
        try (DataOutputStream dataOutputStream = new DataOutputStream(
            new BufferedOutputStream(new FileOutputStream(file)))) {
            dataOutputStream.writeInt(wordMap.size());
            dataOutputStream.writeInt(layerSize);
            double[] syn0 = null;
            for (Entry<String, Neuron> element : wordMap.entrySet()) {
                dataOutputStream.writeUTF(element.getKey());
                syn0 = ((WordNeuron) element.getValue()).syn0;
                for (double d : syn0) {
                    dataOutputStream.writeFloat(((Double) d).floatValue());
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public int getLayerSize() {
        return layerSize;
    }

    public void setLayerSize(int layerSize) {
        this.layerSize = layerSize;
    }

    public int getWindow() {
        return window;
    }

    public void setWindow(int window) {
        this.window = window;
    }

    public double getSample() {
        return sample;
    }

    public void setSample(double sample) {
        this.sample = sample;
    }

    public double getAlpha() {
        return alpha;
    }

    public void setAlpha(double alpha) {
        this.alpha = alpha;
        this.startingAlpha = alpha;
    }

    public Boolean getIsCbow() {
        return isCbow;
    }

    public void setIsCbow(Boolean isCbow) {
        this.isCbow = isCbow;
    }

    public static void main(String[] args) throws IOException {
        Learn learn = new Learn();
        long start = System.currentTimeMillis();
        learn.learnFile(new File("library/xh.txt"));
        System.out.println("use time " + (System.currentTimeMillis() - start));
        learn.saveModel(new File("library/javaVector"));
    }
}
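The five-argument constructor accepts null for any hyperparameter to keep its default. A minimal training sketch with non-default settings, assuming the same library/xh.txt corpus used in the main method above:

import java.io.File;
import java.io.IOException;
import com.ansj.vec.Learn;

public class TrainDemo {
    public static void main(String[] args) throws IOException {
        // CBOW, 100-dimensional vectors, window 5, default alpha and sample
        Learn learn = new Learn(true, 100, 5, null, null);
        learn.learnFile(new File("library/xh.txt"));
        learn.saveModel(new File("library/javaVector"));
    }
}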
3. K-means clustering of word vectors: util/WordKmeans.java
package com.ansj.vec.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.ansj.vec.Word2VEC;

/**
 * k-means clustering of word vectors
 *
 * @author ansj
 */
public class WordKmeans {

    public static void main(String[] args) {
        Word2VEC vec = new Word2VEC();
        try {
            vec.loadJavaModel("C:\\Users\\le\\Desktop\\0328-事件相关法律的算法进展\\javaSkip1");
            System.out.println("中国 " + Arrays.toString(vec.getWordVector("中国")));
            System.out.println("何润东 " + Arrays.toString(vec.getWordVector("何润东")));
            System.out.println("足球 " + Arrays.toString(vec.getWordVector("足球")));
        } catch (IOException e) {
            e.printStackTrace();
        }
        System.out.println("load model ok!");
        WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 50);
        Classes[] explain = wordKmeans.explain();

        for (int i = 0; i < explain.length; i++) {
            System.out.println("--------" + i + "---------");
            System.out.println(explain[i].getTop(10));
        }
    }

    private HashMap<String, float[]> wordMap = null;

    private int iter;

    private Classes[] cArray = null;

    public WordKmeans(HashMap<String, float[]> wordMap, int clcn, int iter) {
        this.wordMap = wordMap;
        this.iter = iter;
        cArray = new Classes[clcn];
    }

    public Classes[] explain() {
        // initialization: use the first clcn words as the initial centers
        Iterator<Entry<String, float[]>> iterator = wordMap.entrySet().iterator();
        for (int i = 0; i < cArray.length; i++) {
            Entry<String, float[]> next = iterator.next();
            cArray[i] = new Classes(i, next.getValue());
        }

        for (int i = 0; i < iter; i++) {
            for (Classes classes : cArray) {
                classes.clean();
            }

            // assignment step: put every word into the nearest cluster
            iterator = wordMap.entrySet().iterator();
            while (iterator.hasNext()) {
                Entry<String, float[]> next = iterator.next();
                double miniScore = Double.MAX_VALUE;
                double tempScore;
                int classesId = 0;
                for (Classes classes : cArray) {
                    tempScore = classes.distance(next.getValue());
                    if (miniScore > tempScore) {
                        miniScore = tempScore;
                        classesId = classes.id;
                    }
                }
                cArray[classesId].putValue(next.getKey(), miniScore);
            }

            // update step: recompute every cluster center
            for (Classes classes : cArray) {
                classes.updateCenter(wordMap);
            }
            System.out.println("iter " + i + " ok!");
        }

        return cArray;
    }

    public static class Classes {
        private int id;

        private float[] center;

        public Classes(int id, float[] center) {
            this.id = id;
            this.center = center.clone();
        }

        Map<String, Double> values = new HashMap<>();

        /** squared Euclidean distance to the cluster center */
        public double distance(float[] value) {
            double sum = 0;
            for (int i = 0; i < value.length; i++) {
                sum += (center[i] - value[i]) * (center[i] - value[i]);
            }
            return sum;
        }

        public void putValue(String word, double score) {
            values.put(word, score);
        }

        /**
         * Recompute the cluster center as the mean of its members.
         *
         * @param wordMap
         */
        public void updateCenter(HashMap<String, float[]> wordMap) {
            if (values.isEmpty()) {
                return; // keep the old center for an empty cluster to avoid NaN
            }
            for (int i = 0; i < center.length; i++) {
                center[i] = 0;
            }
            float[] value = null;
            for (String keyWord : values.keySet()) {
                value = wordMap.get(keyWord);
                for (int i = 0; i < value.length; i++) {
                    center[i] += value[i];
                }
            }
            for (int i = 0; i < center.length; i++) {
                center[i] = center[i] / values.size();
            }
        }

        /**
         * Clear the assignments of the previous iteration.
         */
        public void clean() {
            values.clear();
        }

        /**
         * Get the n words closest to the center of this cluster.
         *
         * @param n
         * @return
         */
        public List<Entry<String, Double>> getTop(int n) {
            List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(
                values.entrySet());
            Collections.sort(arrayList, new Comparator<Map.Entry<String, Double>>() {
                @Override
                public int compare(Entry<String, Double> o1, Entry<String, Double> o2) {
                    // Double.compare keeps the comparator contract; the original
                    // "o1 > o2 ? 1 : -1" never returned 0 and could break TimSort
                    return Double.compare(o1.getValue(), o2.getValue());
                }
            });
            // the original used arrayList.size() - 1, silently dropping one entry
            int min = Math.min(n, arrayList.size());
            if (min <= 0) {
                return Collections.emptyList();
            }
            return arrayList.subList(0, min);
        }

    }

}
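Since WordKmeans only needs a HashMap<String, float[]>, it can be tried without a trained model. A minimal sketch on hypothetical 2-dimensional vectors that form two obvious clusters:

import java.util.HashMap;
import com.ansj.vec.util.WordKmeans;
import com.ansj.vec.util.WordKmeans.Classes;

public class KmeansDemo {
    public static void main(String[] args) {
        HashMap<String, float[]> wordMap = new HashMap<>();
        wordMap.put("a", new float[] { 0.0f, 0.1f });
        wordMap.put("b", new float[] { 0.1f, 0.0f });
        wordMap.put("c", new float[] { 5.0f, 5.1f });
        wordMap.put("d", new float[] { 5.1f, 5.0f });

        // 2 clusters, 10 iterations
        WordKmeans kmeans = new WordKmeans(wordMap, 2, 10);
        Classes[] classes = kmeans.explain();
        for (int i = 0; i < classes.length; i++) {
            System.out.println("cluster " + i + ": " + classes[i].getTop(2));
        }
    }
}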
4. Word-vector utilities: util/Haffman.java and util/MapCount.java
package com.ansj.vec.util;

import java.util.Collection;
import java.util.TreeSet;

import com.ansj.vec.domain.HiddenNeuron;
import com.ansj.vec.domain.Neuron;

/**
 * Build the Huffman coding tree over the vocabulary.
 *
 * @author ansj
 */
public class Haffman {
    private int layerSize;

    public Haffman(int layerSize) {
        this.layerSize = layerSize;
    }

    private TreeSet<Neuron> set = new TreeSet<>();

    public void make(Collection<Neuron> neurons) {
        set.addAll(neurons);
        while (set.size() > 1) {
            merger();
        }
    }

    /**
     * Merge the two lowest-frequency nodes under a new hidden (inner) node;
     * one child gets code 0, the other code 1.
     */
    private void merger() {
        HiddenNeuron hn = new HiddenNeuron(layerSize);
        Neuron min1 = set.pollFirst();
        Neuron min2 = set.pollFirst();
        hn.category = min2.category;
        hn.freq = min1.freq + min2.freq;
        min1.parent = hn;
        min2.parent = hn;
        min1.code = 0;
        min2.code = 1;
        set.add(hn);
    }

}
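A minimal sketch of the tree construction on a hypothetical three-word vocabulary. After make, WordNeuron#makeNeurons (shown in section 5) walks the parent chain to recover each word's Huffman code; rarer words receive longer codes.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.ansj.vec.domain.Neuron;
import com.ansj.vec.domain.WordNeuron;
import com.ansj.vec.util.Haffman;

public class HaffmanDemo {
    public static void main(String[] args) {
        // hypothetical frequencies; layerSize 10 only sizes the hidden vectors
        List<Neuron> vocab = new ArrayList<>();
        vocab.add(new WordNeuron("the", 0.5, 10));
        vocab.add(new WordNeuron("cat", 0.3, 10));
        vocab.add(new WordNeuron("sat", 0.2, 10));

        new Haffman(10).make(vocab);
        for (Neuron n : vocab) {
            WordNeuron w = (WordNeuron) n;
            w.makeNeurons();
            // e.g. "the" gets a 1-bit code, "cat" and "sat" 2-bit codes
            System.out.println(w.name + " -> " + Arrays.toString(w.codeArr));
        }
    }
}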
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by Fernflower decompiler)
//

package com.ansj.vec.util;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

public class MapCount<T> {
    private HashMap<T, Integer> hm = null;

    public MapCount() {
        this.hm = new HashMap<>();
    }

    public MapCount(int initialCapacity) {
        this.hm = new HashMap<>(initialCapacity);
    }

    /** add n to the count of t */
    public void add(T t, int n) {
        Integer integer;
        if ((integer = this.hm.get(t)) != null) {
            this.hm.put(t, integer + n);
        } else {
            this.hm.put(t, n);
        }
    }

    public void add(T t) {
        this.add(t, 1);
    }

    public int size() {
        return this.hm.size();
    }

    public void remove(T t) {
        this.hm.remove(t);
    }

    public HashMap<T, Integer> get() {
        return this.hm;
    }

    /** dump the counts as "key count key count ..." */
    public String getDic() {
        Iterator<Entry<T, Integer>> iterator = this.hm.entrySet().iterator();
        StringBuilder sb = new StringBuilder();
        Entry<T, Integer> next;

        while (iterator.hasNext()) {
            next = iterator.next();
            sb.append(next.getKey());
            sb.append(" ");
            sb.append(next.getValue());
            sb.append(" ");
        }

        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(9223372036854775807L);
    }
}
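MapCount is the frequency counter behind Learn#readVocab. A minimal usage sketch:

import com.ansj.vec.util.MapCount;

public class MapCountDemo {
    public static void main(String[] args) {
        MapCount<String> mc = new MapCount<>();
        for (String token : "the cat sat on the mat".split(" ")) {
            mc.add(token);
        }
        System.out.println(mc.get().get("the")); // 2
        System.out.println(mc.size());           // 5 distinct tokens
        System.out.println(mc.getDic());         // "the 2 cat 1 ..." (order not guaranteed)
    }
}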
5. The word-vector domain package
package com.ansj.vec.domain;

public class HiddenNeuron extends Neuron {

    public double[] syn1; // hidden -> output weights

    public HiddenNeuron(int layerSize) {
        syn1 = new double[layerSize];
    }

}
package com.ansj.vec.domain;

public abstract class Neuron implements Comparable<Neuron> {
    public double freq;
    public Neuron parent;
    public int code;
    // corpus pre-classification (supervised mode); -1 means unclassified
    public int category = -1;

    /**
     * Order by category, then by frequency. Note this never returns 0, so a
     * TreeSet keeps neurons with equal frequency as distinct elements.
     */
    @Override
    public int compareTo(Neuron neuron) {
        if (this.category == neuron.category) {
            if (this.freq > neuron.freq) {
                return 1;
            } else {
                return -1;
            }
        } else if (this.category > neuron.category) {
            return 1;
        } else {
            return -1;
        }
    }
}
package com.ansj.vec.domain;

public class WordEntry implements Comparable<WordEntry> {
    public String name;
    public float score;

    public WordEntry(String name, float score) {
        this.name = name;
        this.score = score;
    }

    @Override
    public String toString() {
        return this.name + " " + score;
    }

    /** sort descending by score, so TreeSet#first() is the closest word */
    @Override
    public int compareTo(WordEntry o) {
        if (this.score < o.score) {
            return 1;
        } else {
            return -1;
        }
    }

}
package com.ansj.vec.domain;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

public class WordNeuron extends Neuron {
    public String name;
    public double[] syn0 = null; // input -> hidden weights
    public List<Neuron> neurons = null; // inner neurons on the Huffman path, root first
    public int[] codeArr = null; // Huffman code of this word

    public List<Neuron> makeNeurons() {
        if (neurons != null) {
            return neurons;
        }
        // walk up the parent chain to the root, then reverse
        Neuron neuron = this;
        neurons = new LinkedList<>();
        while ((neuron = neuron.parent) != null) {
            neurons.add(neuron);
        }
        Collections.reverse(neurons);
        codeArr = new int[neurons.size()];

        for (int i = 1; i < neurons.size(); i++) {
            codeArr[i - 1] = neurons.get(i).code;
        }
        codeArr[codeArr.length - 1] = this.code;

        return neurons;
    }

    public WordNeuron(String name, double freq, int layerSize) {
        this.name = name;
        this.freq = freq;
        this.syn0 = new double[layerSize];
        Random random = new Random();
        for (int i = 0; i < syn0.length; i++) {
            syn0[i] = (random.nextDouble() - 0.5) / layerSize;
        }
    }

    /**
     * Constructor used when building a supervised Huffman tree.
     *
     * @param name
     * @param freq
     * @param layerSize
     */
    public WordNeuron(String name, double freq, int category, int layerSize) {
        this.name = name;
        this.freq = freq;
        this.syn0 = new double[layerSize];
        this.category = category;
        Random random = new Random();
        for (int i = 0; i < syn0.length; i++) {
            syn0[i] = (random.nextDouble() - 0.5) / layerSize;
        }
    }

}