zoukankan      html  css  js  c++  java
  • word2vec的Java源码【转】

    一、核心代码 word2vec.java

      1 package com.ansj.vec;
      2 
      3 import java.io.*;
      4 import java.lang.reflect.Array;
      5 import java.util.ArrayList;
      6 import java.util.Arrays;
      7 import java.util.Collections;
      8 import java.util.HashMap;
      9 import java.util.List;
     10 import java.util.Map;
     11 import java.util.Map.Entry;
     12 import java.util.Set;
     13 import java.util.TreeSet;
     14 
     15 import com.ansj.vec.domain.WordEntry;
     16 import com.ansj.vec.util.WordKmeans;
     17 import com.ansj.vec.util.WordKmeans.Classes;
     18 
     19 public class Word2VEC {
     20 
     21     public static void main(String[] args) throws IOException {
     22 
     23          //Learn learn = new Learn();
     24         //learn.learnFile(new File("C:\Users\le\Desktop\0328-事件相关法律的算法进展\Result_Country.txt"));
     25         //learn.saveModel(new File("C:\Users\le\Desktop\0328-事件相关法律的算法进展\javaSkip1"));
     26         
     27         Word2VEC vec = new Word2VEC();
     28         vec.loadJavaModel("C:\Users\le\Desktop\0328-事件相关法律的算法进展\javaSkip1");
     29         System.out.println("中国" + "	" +Arrays.toString(vec.getWordVector("中国")));
     30         System.out.println("何润东" + "	" +Arrays.toString(vec.getWordVector("何润东")));
     31         System.out.println("足球" + "	" + Arrays.toString(vec.getWordVector("足球")));
     32 
     33         String str = "中国";
     34         System.out.println(vec.distance(str));
     35         WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 10);
     36         Classes[] explain = wordKmeans.explain();
     37         for (int i = 0; i < explain.length; i++) {
     38             System.out.println("--------" + i + "---------");
     39             System.out.println(explain[i].getTop(10));
     40         }
     41     }
     42 
            // Word -> vector; both loaders scale every vector to unit length before storing it.
     43     private HashMap<String, float[]> wordMap = new HashMap<String, float[]>();
     44 
            // Number of words announced by the header of the last loaded model.
     45     private int words;
            // Number of components per vector, read from the model header.
     46     private int size;
            // Upper bound on the number of entries kept by insertTopN and the distance searches.
     47     private int topNSize = 40;
     48 
     49     /**
     50      * 鍔犺浇妯″瀷
     51      * 
     52      * @param path
     53      *            妯″瀷鐨勮矾寰�
     54      * @throws IOException
     55      */
     56     public void loadGoogleModel(String path) throws IOException {
     57         DataInputStream dis = null;
     58         BufferedInputStream bis = null;
     59         double len = 0;
     60         float vector = 0;
     61         try {
     62             bis = new BufferedInputStream(new FileInputStream(path));
     63             dis = new DataInputStream(bis);
     64             // //璇诲彇璇嶆暟
     65             words = Integer.parseInt(readString(dis));
     66             // //澶у皬
     67             size = Integer.parseInt(readString(dis));
     68             String word;
     69             float[] vectors = null;
     70             for (int i = 0; i < words; i++) {
     71                 word = readString(dis);
     72                 vectors = new float[size];
     73                 len = 0;
     74                 for (int j = 0; j < size; j++) {
     75                     vector = readFloat(dis);
     76                     len += vector * vector;
     77                     vectors[j] = (float) vector;
     78                 }
     79                 len = Math.sqrt(len);
     80 
     81                 for (int j = 0; j < size; j++) {
     82                     vectors[j] /= len;
     83                 }
     84 
     85                 wordMap.put(word, vectors);
     86                 dis.read();
     87             }
     88         } finally {
     89             bis.close();
     90             dis.close();
     91         }
     92     }
     93 
     94     /**
     95      * 鍔犺浇妯″瀷
     96      * 
     97      * @param path
     98      *            妯″瀷鐨勮矾寰�
     99      * @throws IOException
    100      */
    101     public void loadJavaModel(String path) throws IOException {
    102         try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)))) {
    103             words = dis.readInt();
    104             size = dis.readInt();
    105 
    106             float vector = 0;
    107 
    108             String key = null;
    109             float[] value = null;
    110             for (int i = 0; i < words; i++) {
    111                 double len = 0;
    112                 key = dis.readUTF();
    113                 value = new float[size];
    114                 for (int j = 0; j < size; j++) {
    115                     vector = dis.readFloat();
    116                     len += vector * vector;
    117                     value[j] = vector;
    118                 }
    119 
    120                 len = Math.sqrt(len);
    121 
    122                 for (int j = 0; j < size; j++) {
    123                     value[j] /= len;
    124                 }
    125                 wordMap.put(key, value);
    126             }
    127 
    128         }
    129     }
    130 
            // Size in bytes of the scratch buffer readString uses while collecting a token.
    131     private static final int MAX_SIZE = 50;
    132 
    133     /**
    134      * 杩戜箟璇�
    135      * 
    136      * @return
    137      */
    138     public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
    139         float[] wv0 = getWordVector(word0);
    140         float[] wv1 = getWordVector(word1);
    141         float[] wv2 = getWordVector(word2);
    142 
    143         if (wv1 == null || wv2 == null || wv0 == null) {
    144             return null;
    145         }
    146         float[] wordVector = new float[size];
    147         for (int i = 0; i < size; i++) {
    148             wordVector[i] = wv1[i] - wv0[i] + wv2[i];
    149         }
    150         float[] tempVector;
    151         String name;
    152         List<WordEntry> wordEntrys = new ArrayList<WordEntry>(topNSize);
    153         for (Entry<String, float[]> entry : wordMap.entrySet()) {
    154             name = entry.getKey();
    155             if (name.equals(word0) || name.equals(word1) || name.equals(word2)) {
    156                 continue;
    157             }
    158             float dist = 0;
    159             tempVector = entry.getValue();
    160             for (int i = 0; i < wordVector.length; i++) {
    161                 dist += wordVector[i] * tempVector[i];
    162             }
    163             insertTopN(name, dist, wordEntrys);
    164         }
    165         return new TreeSet<WordEntry>(wordEntrys);
    166     }
    167 
    168     private void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
    169         // TODO Auto-generated method stub
    170         if (wordsEntrys.size() < topNSize) {
    171             wordsEntrys.add(new WordEntry(name, score));
    172             return;
    173         }
    174         float min = Float.MAX_VALUE;
    175         int minOffe = 0;
    176         for (int i = 0; i < topNSize; i++) {
    177             WordEntry wordEntry = wordsEntrys.get(i);
    178             if (min > wordEntry.score) {
    179                 min = wordEntry.score;
    180                 minOffe = i;
    181             }
    182         }
    183 
    184         if (score > min) {
    185             wordsEntrys.set(minOffe, new WordEntry(name, score));
    186         }
    187 
    188     }
    189 
    190     public Set<WordEntry> distance(String queryWord) {
    191 
    192         float[] center = wordMap.get(queryWord);
    193         if (center == null) {
    194             return Collections.emptySet();
    195         }
    196 
    197         int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
    198         TreeSet<WordEntry> result = new TreeSet<WordEntry>();
    199 
    200         double min = Float.MIN_VALUE;
    201         for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
    202             float[] vector = entry.getValue();
    203             float dist = 0;
    204             for (int i = 0; i < vector.length; i++) {
    205                 dist += center[i] * vector[i];
    206             }
    207 
    208             if (dist > min) {
    209                 result.add(new WordEntry(entry.getKey(), dist));
    210                 if (resultSize < result.size()) {
    211                     result.pollLast();
    212                 }
    213                 min = result.last().score;
    214             }
    215         }
    216         result.pollFirst();
    217 
    218         return result;
    219     }
    220 
    221     public Set<WordEntry> distance(List<String> words) {
    222 
    223         float[] center = null;
    224         for (String word : words) {
    225             center = sum(center, wordMap.get(word));
    226         }
    227 
    228         if (center == null) {
    229             return Collections.emptySet();
    230         }
    231 
    232         int resultSize = wordMap.size() < topNSize ? wordMap.size() : topNSize;
    233         TreeSet<WordEntry> result = new TreeSet<WordEntry>();
    234 
    235         double min = Float.MIN_VALUE;
    236         for (Map.Entry<String, float[]> entry : wordMap.entrySet()) {
    237             float[] vector = entry.getValue();
    238             float dist = 0;
    239             for (int i = 0; i < vector.length; i++) {
    240                 dist += center[i] * vector[i];
    241             }
    242 
    243             if (dist > min) {
    244                 result.add(new WordEntry(entry.getKey(), dist));
    245                 if (resultSize < result.size()) {
    246                     result.pollLast();
    247                 }
    248                 min = result.last().score;
    249             }
    250         }
    251         result.pollFirst();
    252 
    253         return result;
    254     }
    255 
    256     private float[] sum(float[] center, float[] fs) {
    257         // TODO Auto-generated method stub
    258 
    259         if (center == null && fs == null) {
    260             return null;
    261         }
    262 
    263         if (fs == null) {
    264             return center;
    265         }
    266 
    267         if (center == null) {
    268             return fs;
    269         }
    270 
    271         for (int i = 0; i < fs.length; i++) {
    272             center[i] += fs[i];
    273         }
    274 
    275         return center;
    276     }
    277 
    278     /**
    279      * 寰楀埌璇嶅悜閲�
    280      * 
    281      * @param word
    282      * @return
    283      */
    284     public float[] getWordVector(String word) {
    285         return wordMap.get(word);
    286     }
    287 
    288     public static float readFloat(InputStream is) throws IOException {
    289         byte[] bytes = new byte[4];
    290         is.read(bytes);
    291         return getFloat(bytes);
    292     }
    293 
    294     /**
    295      * 璇诲彇涓�涓猣loat
    296      * 
    297      * @param b
    298      * @return
    299      */
    300     public static float getFloat(byte[] b) {
    301         int accum = 0;
    302         accum = accum | (b[0] & 0xff) << 0;
    303         accum = accum | (b[1] & 0xff) << 8;
    304         accum = accum | (b[2] & 0xff) << 16;
    305         accum = accum | (b[3] & 0xff) << 24;
    306         return Float.intBitsToFloat(accum);
    307     }
    308 
    309     /**
    310      * 璇诲彇涓�涓�瓧绗︿覆
    311      * 
    312      * @param dis
    313      * @return
    314      * @throws IOException
    315      */
    316     private static String readString(DataInputStream dis) throws IOException {
    317         // TODO Auto-generated method stub
    318         byte[] bytes = new byte[MAX_SIZE];
    319         byte b = dis.readByte();
    320         int i = -1;
    321         StringBuilder sb = new StringBuilder();
    322         while (b != 32 && b != 10) {
    323             i++;
    324             bytes[i] = b;
    325             b = dis.readByte();
    326             if (i == 49) {
    327                 sb.append(new String(bytes));
    328                 i = -1;
    329                 bytes = new byte[MAX_SIZE];
    330             }
    331         }
    332         sb.append(new String(bytes, 0, i + 1));
    333         return sb.toString();
    334     }
    335 
    336     public int getTopNSize() {
    337         return topNSize;
    338     }
    339 
            /**
             * Sets the maximum number of entries kept by the nearest-word searches.
             *
             * @param topNSize new limit
             */
    340     public void setTopNSize(int topNSize) {
    341         this.topNSize = topNSize;
    342     }
    343 
    344     public HashMap<String, float[]> getWordMap() {
    345         return wordMap;
    346     }
    347 
    348     public int getWords() {
    349         return words;
    350     }
    351 
    352     public int getSize() {
    353         return size;
    354     }
    355 
    356 }

     二、词向量-模型学习代码learn.java

      1 package com.ansj.vec;
      2 
      3 import java.io.BufferedOutputStream;
      4 import java.io.BufferedReader;
      5 import java.io.DataOutputStream;
      6 import java.io.File;
      7 import java.io.FileInputStream;
      8 import java.io.FileNotFoundException;
      9 import java.io.FileOutputStream;
     10 import java.io.IOException;
     11 import java.io.InputStreamReader;
     12 import java.util.ArrayList;
     13 import java.util.HashMap;
     14 import java.util.List;
     15 import java.util.Map;
     16 import java.util.Map.Entry;
     17 
     18 import com.ansj.vec.util.MapCount;
     19 import com.ansj.vec.domain.HiddenNeuron;
     20 import com.ansj.vec.domain.Neuron;
     21 import com.ansj.vec.domain.WordNeuron;
     22 import com.ansj.vec.util.Haffman;
     23 
     24 public class Learn {
     25 
          // Vocabulary: word -> neuron, filled by readVocab / readVocabWithSupervised.
     26   private Map<String, Neuron> wordMap = new HashMap<>();
     27   /**
     28    * Number of features to train (length of each word vector).
     29    */
     30   private int layerSize = 200;
     31 
     32   /**
     33    * Size of the context window.
     34    */
     35   private int window = 5;
     36 
          // Subsampling threshold used by trainModel; subsampling is skipped when it is <= 0.
     37   private double sample = 1e-3;
          // Current learning rate; trainModel lowers it as training progresses.
     38   private double alpha = 0.025;
          // Learning rate that the decay in trainModel starts from.
     39   private double startingAlpha = alpha;
     40 
          // Number of entries in expTable.
     41   public int EXP_TABLE_SIZE = 1000;
     42 
          // true: train with cbowGram; false: train with skipGram.
     43   private Boolean isCbow = false;
     44 
          // Lookup table filled by createExpTable.
     45   private double[] expTable = new double[EXP_TABLE_SIZE];
     46 
          // Total number of tokens counted while reading the vocabulary.
     47   private int trainWordsCount = 0;
     48 
          // Activations outside (-MAX_EXP, MAX_EXP) are skipped during training.
     49   private int MAX_EXP = 6;
     50 
     51   public Learn(Boolean isCbow, Integer layerSize, Integer window, Double alpha,
     52       Double sample) {
     53     createExpTable();
     54     if (isCbow != null) {
     55       this.isCbow = isCbow;
     56     }
     57     if (layerSize != null)
     58       this.layerSize = layerSize;
     59     if (window != null)
     60       this.window = window;
     61     if (alpha != null)
     62       this.alpha = alpha;
     63     if (sample != null)
     64       this.sample = sample;
     65   }
     66 
     67   public Learn() {
     68     createExpTable();
     69   }
     70 
     71   /**
     72    * trainModel
     73    * 
     74    * @throws IOException
     75    */
     76   private void trainModel(File file) throws IOException {
     77     try (BufferedReader br = new BufferedReader(new InputStreamReader(
     78         new FileInputStream(file)))) {
     79       String temp = null;
     80       long nextRandom = 5;
     81       int wordCount = 0;
     82       int lastWordCount = 0;
     83       int wordCountActual = 0;
     84       while ((temp = br.readLine()) != null) {
     85         if (wordCount - lastWordCount > 10000) {
     86           System.out.println("alpha:" + alpha + "	Progress: "
     87               + (int) (wordCountActual / (double) (trainWordsCount + 1) * 100)
     88               + "%");
     89           wordCountActual += wordCount - lastWordCount;
     90           lastWordCount = wordCount;
     91           alpha = startingAlpha
     92               * (1 - wordCountActual / (double) (trainWordsCount + 1));
     93           if (alpha < startingAlpha * 0.0001) {
     94             alpha = startingAlpha * 0.0001;
     95           }
     96         }
     97         String[] strs = temp.split(" ");
     98         wordCount += strs.length;
     99         List<WordNeuron> sentence = new ArrayList<WordNeuron>();
    100         for (int i = 0; i < strs.length; i++) {
    101           Neuron entry = wordMap.get(strs[i]);
    102           if (entry == null) {
    103             continue;
    104           }
    105           // The subsampling randomly discards frequent words while keeping the
    106           // ranking same
    107           if (sample > 0) {
    108             double ran = (Math.sqrt(entry.freq / (sample * trainWordsCount)) + 1)
    109                 * (sample * trainWordsCount) / entry.freq;
    110             nextRandom = nextRandom * 25214903917L + 11;
    111             if (ran < (nextRandom & 0xFFFF) / (double) 65536) {
    112               continue;
    113             }
    114           }
    115           sentence.add((WordNeuron) entry);
    116         }
    117 
    118         for (int index = 0; index < sentence.size(); index++) {
    119           nextRandom = nextRandom * 25214903917L + 11;
    120           if (isCbow) {
    121             cbowGram(index, sentence, (int) nextRandom % window);
    122           } else {
    123             skipGram(index, sentence, (int) nextRandom % window);
    124           }
    125         }
    126 
    127       }
    128       System.out.println("Vocab size: " + wordMap.size());
    129       System.out.println("Words in train file: " + trainWordsCount);
    130       System.out.println("sucess train over!");
    131     }
    132   }
    133 
    134   /**
    135    * skip gram 模型训练
    136    * 
    137    * @param sentence
    138    * @param neu1
    139    */
    140   private void skipGram(int index, List<WordNeuron> sentence, int b) {
    141     // TODO Auto-generated method stub
    142     WordNeuron word = sentence.get(index);
    143     int a, c = 0;
    144     for (a = b; a < window * 2 + 1 - b; a++) {
    145       if (a == window) {
    146         continue;
    147       }
    148       c = index - window + a;
    149       if (c < 0 || c >= sentence.size()) {
    150         continue;
    151       }
    152 
    153       double[] neu1e = new double[layerSize];// 误差项
    154       // HIERARCHICAL SOFTMAX
    155       List<Neuron> neurons = word.neurons;
    156       WordNeuron we = sentence.get(c);
    157       for (int i = 0; i < neurons.size(); i++) {
    158         HiddenNeuron out = (HiddenNeuron) neurons.get(i);
    159         double f = 0;
    160         // Propagate hidden -> output
    161         for (int j = 0; j < layerSize; j++) {
    162           f += we.syn0[j] * out.syn1[j];
    163         }
    164         if (f <= -MAX_EXP || f >= MAX_EXP) {
    165           continue;
    166         } else {
    167           f = (f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2);
    168           f = expTable[(int) f];
    169         }
    170         // 'g' is the gradient multiplied by the learning rate
    171         double g = (1 - word.codeArr[i] - f) * alpha;
    172         // Propagate errors output -> hidden
    173         for (c = 0; c < layerSize; c++) {
    174           neu1e[c] += g * out.syn1[c];
    175         }
    176         // Learn weights hidden -> output
    177         for (c = 0; c < layerSize; c++) {
    178           out.syn1[c] += g * we.syn0[c];
    179         }
    180       }
    181 
    182       // Learn weights input -> hidden
    183       for (int j = 0; j < layerSize; j++) {
    184         we.syn0[j] += neu1e[j];
    185       }
    186     }
    187 
    188   }
    189 
    190   /**
    191    * 词袋模型
    192    * 
    193    * @param index
    194    * @param sentence
    195    * @param b
    196    */
    197   private void cbowGram(int index, List<WordNeuron> sentence, int b) {
    198     WordNeuron word = sentence.get(index);
    199     int a, c = 0;
    200 
    201     List<Neuron> neurons = word.neurons;
    202     double[] neu1e = new double[layerSize];// 误差项
    203     double[] neu1 = new double[layerSize];// 误差项
    204     WordNeuron last_word;
    205 
    206     for (a = b; a < window * 2 + 1 - b; a++)
    207       if (a != window) {
    208         c = index - window + a;
    209         if (c < 0)
    210           continue;
    211         if (c >= sentence.size())
    212           continue;
    213         last_word = sentence.get(c);
    214         if (last_word == null)
    215           continue;
    216         for (c = 0; c < layerSize; c++)
    217           neu1[c] += last_word.syn0[c];
    218       }
    219 
    220     // HIERARCHICAL SOFTMAX
    221     for (int d = 0; d < neurons.size(); d++) {
    222       HiddenNeuron out = (HiddenNeuron) neurons.get(d);
    223       double f = 0;
    224       // Propagate hidden -> output
    225       for (c = 0; c < layerSize; c++)
    226         f += neu1[c] * out.syn1[c];
    227       if (f <= -MAX_EXP)
    228         continue;
    229       else if (f >= MAX_EXP)
    230         continue;
    231       else
    232         f = expTable[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))];
    233       // 'g' is the gradient multiplied by the learning rate
    234       // double g = (1 - word.codeArr[d] - f) * alpha;
    235       // double g = f*(1-f)*( word.codeArr[i] - f) * alpha;
    236       double g = f * (1 - f) * (word.codeArr[d] - f) * alpha;
    237       //
    238       for (c = 0; c < layerSize; c++) {
    239         neu1e[c] += g * out.syn1[c];
    240       }
    241       // Learn weights hidden -> output
    242       for (c = 0; c < layerSize; c++) {
    243         out.syn1[c] += g * neu1[c];
    244       }
    245     }
    246     for (a = b; a < window * 2 + 1 - b; a++) {
    247       if (a != window) {
    248         c = index - window + a;
    249         if (c < 0)
    250           continue;
    251         if (c >= sentence.size())
    252           continue;
    253         last_word = sentence.get(c);
    254         if (last_word == null)
    255           continue;
    256         for (c = 0; c < layerSize; c++)
    257           last_word.syn0[c] += neu1e[c];
    258       }
    259 
    260     }
    261   }
    262 
    263   /**
    264    * 统计词频
    265    * 
    266    * @param file
    267    * @throws IOException
    268    */
    269   private void readVocab(File file) throws IOException {
    270     MapCount<String> mc = new MapCount<>();
    271     try (BufferedReader br = new BufferedReader(new InputStreamReader(
    272         new FileInputStream(file)))) {
    273       String temp = null;
    274       while ((temp = br.readLine()) != null) {
    275         String[] split = temp.split(" ");
    276         trainWordsCount += split.length;
    277         for (String string : split) {
    278           mc.add(string);
    279         }
    280       }
    281     }
    282     for (Entry<String, Integer> element : mc.get().entrySet()) {
    283       wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
    284           (double) element.getValue() / mc.size(), layerSize));
    285     }
    286   }
    287 
    288   /**
    289    * 对文本进行预分类
    290    * 
    291    * @param files
    292    * @throws IOException
    293    * @throws FileNotFoundException
    294    */
    295   private void readVocabWithSupervised(File[] files) throws IOException {
    296     for (int category = 0; category < files.length; category++) {
    297       // 对多个文件学习
    298       MapCount<String> mc = new MapCount<>();
    299       try (BufferedReader br = new BufferedReader(new InputStreamReader(
    300           new FileInputStream(files[category])))) {
    301         String temp = null;
    302         while ((temp = br.readLine()) != null) {
    303           String[] split = temp.split(" ");
    304           trainWordsCount += split.length;
    305           for (String string : split) {
    306             mc.add(string);
    307           }
    308         }
    309       }
    310       for (Entry<String, Integer> element : mc.get().entrySet()) {
    311         double tarFreq = (double) element.getValue() / mc.size();
    312         if (wordMap.get(element.getKey()) != null) {
    313           double srcFreq = wordMap.get(element.getKey()).freq;
    314           if (srcFreq >= tarFreq) {
    315             continue;
    316           } else {
    317             Neuron wordNeuron = wordMap.get(element.getKey());
    318             wordNeuron.category = category;
    319             wordNeuron.freq = tarFreq;
    320           }
    321         } else {
    322           wordMap.put(element.getKey(), new WordNeuron(element.getKey(),
    323               tarFreq, category, layerSize));
    324         }
    325       }
    326     }
    327   }
    328 
    329   /**
    330    * Precompute the exp() table f(x) = x / (x + 1)
    331    */
    332   private void createExpTable() {
    333     for (int i = 0; i < EXP_TABLE_SIZE; i++) {
    334       expTable[i] = Math.exp(((i / (double) EXP_TABLE_SIZE * 2 - 1) * MAX_EXP));
    335       expTable[i] = expTable[i] / (expTable[i] + 1);
    336     }
    337   }
    338 
    339   /**
    340    * 根据文件学习
    341    * 
    342    * @param file
    343    * @throws IOException
    344    */
    345   public void learnFile(File file) throws IOException {
    346     readVocab(file);
    347     new Haffman(layerSize).make(wordMap.values());
    348 
    349     // 查找每个神经元
    350     for (Neuron neuron : wordMap.values()) {
    351       ((WordNeuron) neuron).makeNeurons();
    352     }
    353 
    354     trainModel(file);
    355   }
    356 
    357   /**
    358    * 根据预分类的文件学习
    359    * 
    360    * @param summaryFile
    361    *          合并文件
    362    * @param classifiedFiles
    363    *          分类文件
    364    * @throws IOException
    365    */
    366   public void learnFile(File summaryFile, File[] classifiedFiles)
    367       throws IOException {
    368     readVocabWithSupervised(classifiedFiles);
    369     new Haffman(layerSize).make(wordMap.values());
    370     // 查找每个神经元
    371     for (Neuron neuron : wordMap.values()) {
    372       ((WordNeuron) neuron).makeNeurons();
    373     }
    374     trainModel(summaryFile);
    375   }
    376 
    377   /**
    378    * 保存模型
    379    */
    380   public void saveModel(File file) {
    381     // TODO Auto-generated method stub
    382 
    383     try (DataOutputStream dataOutputStream = new DataOutputStream(
    384         new BufferedOutputStream(new FileOutputStream(file)))) {
    385       dataOutputStream.writeInt(wordMap.size());
    386       dataOutputStream.writeInt(layerSize);
    387       double[] syn0 = null;
    388       for (Entry<String, Neuron> element : wordMap.entrySet()) {
    389         dataOutputStream.writeUTF(element.getKey());
    390         syn0 = ((WordNeuron) element.getValue()).syn0;
    391         for (double d : syn0) {
    392           dataOutputStream.writeFloat(((Double) d).floatValue());
    393         }
    394       }
    395     } catch (IOException e) {
    396       // TODO Auto-generated catch block
    397       e.printStackTrace();
    398     }
    399   }
    400 
    401   public int getLayerSize() {
    402     return layerSize;
    403   }
    404 
          /**
           * Sets the number of features per word vector.
           *
           * @param layerSize new vector length
           */
    405   public void setLayerSize(int layerSize) {
    406     this.layerSize = layerSize;
    407   }
    408 
    409   public int getWindow() {
    410     return window;
    411   }
    412 
          /**
           * Sets the context window size.
           *
           * @param window new window size
           */
    413   public void setWindow(int window) {
    414     this.window = window;
    415   }
    416 
    417   public double getSample() {
    418     return sample;
    419   }
    420 
          /**
           * Sets the subsampling threshold; training skips subsampling when it is not positive.
           *
           * @param sample new threshold
           */
    421   public void setSample(double sample) {
    422     this.sample = sample;
    423   }
    424 
    425   public double getAlpha() {
    426     return alpha;
    427   }
    428 
          /**
           * Sets the learning rate, both the current value and the value the decay starts from.
           *
           * @param alpha new learning rate
           */
    429   public void setAlpha(double alpha) {
    430     this.alpha = alpha;
    431     this.startingAlpha = alpha;
    432   }
    433 
    434   public Boolean getIsCbow() {
    435     return isCbow;
    436   }
    437 
          /**
           * Chooses the training mode.
           *
           * @param isCbow {@code true} for CBOW, {@code false} for skip-gram
           */
    438   public void setIsCbow(Boolean isCbow) {
    439     this.isCbow = isCbow;
    440   }
    441 
    442   public static void main(String[] args) throws IOException {
    443     Learn learn = new Learn();
    444     long start = System.currentTimeMillis();
    445     learn.learnFile(new File("library/xh.txt"));
    446     System.out.println("use time " + (System.currentTimeMillis() - start));
    447     learn.saveModel(new File("library/javaVector"));
    448 
    449   }
    450 }

    三、词向量的kmeans聚类 util-----wordKmeans.java

      1 package com.ansj.vec.util;
      2 
      3 import java.io.IOException;
      4 import java.util.ArrayList;
      5 import java.util.Arrays;
      6 import java.util.Collections;
      7 import java.util.Comparator;
      8 import java.util.HashMap;
      9 import java.util.Iterator;
     10 import java.util.List;
     11 import java.util.Map;
     12 import java.util.Map.Entry;
     13 
     14 import com.ansj.vec.Word2VEC;
     15 /*import com.ansj.vec.domain.WordEntry;
     16 import com.ansj.vec.util.WordKmeans.Classes;*/
     17 /**
     18  * K-means clustering of word vectors
     19  * 
     20  * @author ansj
     21  * 
     22  */
     23 public class WordKmeans {
     24 
     25     public static void main(String[] args) {
     26         Word2VEC vec = new Word2VEC();
     27         try {
     28             
     29             vec.loadJavaModel("C:\Users\le\Desktop\0328-事件相关法律的算法进展\javaSkip1");
     30             System.out.println("中国" + "	" +Arrays.toString(vec.getWordVector("中国")));
     31             System.out.println("何润东" + "	" +Arrays.toString(vec.getWordVector("何润东")));
     32             System.out.println("足球" + "	" + Arrays.toString(vec.getWordVector("足球")));
     33         } catch (IOException e) {
     34             // TODO Auto-generated catch block
     35             e.printStackTrace();
     36         }
     37         System.out.println("load model ok!");
     38         WordKmeans wordKmeans = new WordKmeans(vec.getWordMap(), 50, 50);
     39         Classes[] explain = wordKmeans.explain();
     40 
     41         for (int i = 0; i < explain.length; i++) {
     42             System.out.println("--------" + i + "---------");
     43             System.out.println(explain[i].getTop(10));
     44         }
     45 
     46     }
     47 
            // Word -> vector map to cluster, supplied through the constructor.
     48     private HashMap<String, float[]> wordMap = null;
     49 
            // Number of iterations explain() runs.
     50     private int iter;
     51 
            // One slot per cluster; the slots are filled by explain().
     52     private Classes[] cArray = null;
     53 
     54     public WordKmeans(HashMap<String, float[]> wordMap, int clcn, int iter) {
     55         this.wordMap = wordMap;
     56         this.iter = iter;
     57         cArray = new Classes[clcn];
     58     }
     59 
     60     public Classes[] explain() {
     61         //first 取前clcn个点
     62         Iterator<Entry<String, float[]>> iterator = wordMap.entrySet().iterator();
     63         for (int i = 0; i < cArray.length; i++) {
     64             Entry<String, float[]> next = iterator.next();
     65             cArray[i] = new Classes(i, next.getValue());
     66         }
     67 
     68         for (int i = 0; i < iter; i++) {
     69             for (Classes classes : cArray) {
     70                 classes.clean();
     71             }
     72 
     73             iterator = wordMap.entrySet().iterator();
     74             while (iterator.hasNext()) {
     75                 Entry<String, float[]> next = iterator.next();
     76                 double miniScore = Double.MAX_VALUE;
     77                 double tempScore;
     78                 int classesId = 0;
     79                 for (Classes classes : cArray) {
     80                     tempScore = classes.distance(next.getValue());
     81                     if (miniScore > tempScore) {
     82                         miniScore = tempScore;
     83                         classesId = classes.id;
     84                     }
     85                 }
     86                 cArray[classesId].putValue(next.getKey(), miniScore);
     87             }
     88 
     89             for (Classes classes : cArray) {
     90                 classes.updateCenter(wordMap);
     91             }
     92             System.out.println("iter " + i + " ok!");
     93         }
     94 
     95         return cArray;
     96     }
     97 
     98     public static class Classes {
     99         private int id;
    100 
    101         private float[] center;
    102 
    103         public Classes(int id, float[] center) {
    104             this.id = id;
    105             this.center = center.clone();
    106         }
    107 
    108         Map<String, Double> values = new HashMap<>();
    109 
    110         public double distance(float[] value) {
    111             double sum = 0;
    112             for (int i = 0; i < value.length; i++) {
    113                 sum += (center[i] - value[i])*(center[i] - value[i]) ;
    114             }
    115             return sum ;
    116         }
    117 
    118         public void putValue(String word, double score) {
    119             values.put(word, score);
    120         }
    121 
    122         /**
    123          * 重新计算中心点
    124          * @param wordMap
    125          */
    126         public void updateCenter(HashMap<String, float[]> wordMap) {
    127             for (int i = 0; i < center.length; i++) {
    128                 center[i] = 0;
    129             }
    130             float[] value = null;
    131             for (String keyWord : values.keySet()) {
    132                 value = wordMap.get(keyWord);
    133                 for (int i = 0; i < value.length; i++) {
    134                     center[i] += value[i];
    135                 }
    136             }
    137             for (int i = 0; i < center.length; i++) {
    138                 center[i] = center[i] / values.size();
    139             }
    140         }
    141 
    142         /**
    143          * 清空历史结果
    144          */
    145         public void clean() {
    146             // TODO Auto-generated method stub
    147             values.clear();
    148         }
    149 
    150         /**
    151          * 取得每个类别的前n个结果
    152          * @param n
    153          * @return 
    154          */
    155         public List<Entry<String, Double>> getTop(int n) {
    156             List<Map.Entry<String, Double>> arrayList = new ArrayList<Map.Entry<String, Double>>(
    157                 values.entrySet());
    158             Collections.sort(arrayList, new Comparator<Map.Entry<String, Double>>() {
    159                 @Override
    160                 public int compare(Entry<String, Double> o1, Entry<String, Double> o2) {
    161                     // TODO Auto-generated method stub
    162                     return o1.getValue() > o2.getValue() ? 1 : -1;
    163                 }
    164             });
    165             int min = Math.min(n, arrayList.size() - 1);
    166             if(min<=1)return Collections.emptyList() ;
    167             return arrayList.subList(0, min);
    168         }
    169 
    170     }
    171 
    172 }

    四、词向量的 util 包:Haffman.java、MapCount.java

     1 package com.ansj.vec.util;
     2 
     3 import java.util.Collection;
     4 import java.util.List;
     5 import java.util.TreeSet;
     6 
     7 import com.ansj.vec.domain.HiddenNeuron;
     8 import com.ansj.vec.domain.Neuron;
     9 
    10 /**
    11  * 构建Haffman编码树
    12  * 
    13  * @author ansj
    14  *
    15  */
    16 public class Haffman {
    17   private int layerSize;
    18 
    19   public Haffman(int layerSize) {
    20     this.layerSize = layerSize;
    21   }
    22 
    23   private TreeSet<Neuron> set = new TreeSet<>();
    24 
    25   public void make(Collection<Neuron> neurons) {
    26     set.addAll(neurons);
    27     while (set.size() > 1) {
    28       merger();
    29     }
    30   }
    31 
    32   private void merger() {
    33     HiddenNeuron hn = new HiddenNeuron(layerSize);
    34     Neuron min1 = set.pollFirst();
    35     Neuron min2 = set.pollFirst();
    36     hn.category = min2.category;
    37     hn.freq = min1.freq + min2.freq;
    38     min1.parent = hn;
    39     min2.parent = hn;
    40     min1.code = 0;
    41     min2.code = 1;
    42     set.add(hn);
    43   }
    44 
    45 }
     1 //
     2 // Source code recreated from a .class file by IntelliJ IDEA
     3 // (powered by Fernflower decompiler)
     4 //
     5 
     6 package com.ansj.vec.util;
     7 
     8 import java.util.HashMap;
     9 import java.util.Iterator;
    10 import java.util.Map.Entry;
    11 
    12 public class MapCount<T> {
    13     private HashMap<T, Integer> hm = null;
    14 
    15     public MapCount() {
    16         this.hm = new HashMap();
    17     }
    18 
    19     public MapCount(int initialCapacity) {
    20         this.hm = new HashMap(initialCapacity);
    21     }
    22 
    23     public void add(T t, int n) {
    24         Integer integer = null;
    25         if((integer = (Integer)this.hm.get(t)) != null) {
    26             this.hm.put(t, Integer.valueOf(integer.intValue() + n));
    27         } else {
    28             this.hm.put(t, Integer.valueOf(n));
    29         }
    30 
    31     }
    32 
    33     public void add(T t) {
    34         this.add(t, 1);
    35     }
    36 
    37     public int size() {
    38         return this.hm.size();
    39     }
    40 
    41     public void remove(T t) {
    42         this.hm.remove(t);
    43     }
    44 
    45     public HashMap<T, Integer> get() {
    46         return this.hm;
    47     }
    48 
    49     public String getDic() {
    50         Iterator iterator = this.hm.entrySet().iterator();
    51         StringBuilder sb = new StringBuilder();
    52         Entry next = null;
    53 
    54         while(iterator.hasNext()) {
    55             next = (Entry)iterator.next();
    56             sb.append(next.getKey());
    57             sb.append("	");
    58             sb.append(next.getValue());
    59             sb.append("
    ");
    60         }
    61 
    62         return sb.toString();
    63     }
    64 
    65     public static void main(String[] args) {
    66         System.out.println(9223372036854775807L);
    67     }
    68 }

    五、词向量的domain包

     1 package com.ansj.vec.domain;
     2 
     3 public class HiddenNeuron extends Neuron{
     4     
     5     public double[] syn1 ; //hidden->out
     6     
     7     public HiddenNeuron(int layerSize){
     8         syn1 = new double[layerSize] ;
     9     }
    10     
    11 }
     1 package com.ansj.vec.domain;
     2 
     3 public abstract class Neuron implements Comparable<Neuron> {
     4   public double freq;
     5   public Neuron parent;
     6   public int code;
     7   // 语料预分类
     8   public int category = -1;
     9 
    10   @Override
    11   public int compareTo(Neuron neuron) {
    12     if (this.category == neuron.category) {
    13       if (this.freq > neuron.freq) {
    14         return 1;
    15       } else {
    16         return -1;
    17       }
    18     } else if (this.category > neuron.category) {
    19       return 1;
    20     } else {
    21       return -1;
    22     }
    23   }
    24 }
     1 package com.ansj.vec.domain;
     2 
     3 
     4 public class WordEntry implements Comparable<WordEntry> {
     5     public String name;
     6     public float score;
     7 
     8     public WordEntry(String name, float score) {
     9         this.name = name;
    10         this.score = score;
    11     }
    12 
    13     @Override
    14     public String toString() {
    15         // TODO Auto-generated method stub
    16         return this.name + "	" + score;
    17     }
    18 
    19     @Override
    20     public int compareTo(WordEntry o) {
    21         // TODO Auto-generated method stub
    22         if (this.score < o.score) {
    23             return 1;
    24         } else {
    25             return -1;
    26         }
    27     }
    28 
    29 }
     1 package com.ansj.vec.domain;
     2 
     3 import java.util.Collections;
     4 import java.util.LinkedList;
     5 import java.util.List;
     6 import java.util.Random;
     7 
     8 public class WordNeuron extends Neuron {
     9   public String name;
    10   public double[] syn0 = null; // input->hidden
    11   public List<Neuron> neurons = null;// 路径神经元
    12   public int[] codeArr = null;
    13 
    14   public List<Neuron> makeNeurons() {
    15     if (neurons != null) {
    16       return neurons;
    17     }
    18     Neuron neuron = this;
    19     neurons = new LinkedList<>();
    20     while ((neuron = neuron.parent) != null) {
    21       neurons.add(neuron);
    22     }
    23     Collections.reverse(neurons);
    24     codeArr = new int[neurons.size()];
    25 
    26     for (int i = 1; i < neurons.size(); i++) {
    27       codeArr[i - 1] = neurons.get(i).code;
    28     }
    29     codeArr[codeArr.length - 1] = this.code;
    30 
    31     return neurons;
    32   }
    33 
    34   public WordNeuron(String name, double freq, int layerSize) {
    35     this.name = name;
    36     this.freq = freq;
    37     this.syn0 = new double[layerSize];
    38     Random random = new Random();
    39     for (int i = 0; i < syn0.length; i++) {
    40       syn0[i] = (random.nextDouble() - 0.5) / layerSize;
    41     }
    42   }
    43 
    44   /**
    45    * 用于有监督的创造hoffman tree
    46    * 
    47    * @param name
    48    * @param freq
    49    * @param layerSize
    50    */
    51   public WordNeuron(String name, double freq, int category, int layerSize) {
    52     this.name = name;
    53     this.freq = freq;
    54     this.syn0 = new double[layerSize];
    55     this.category = category;
    56     Random random = new Random();
    57     for (int i = 0; i < syn0.length; i++) {
    58       syn0[i] = (random.nextDouble() - 0.5) / layerSize;
    59     }
    60   }
    61 
    62 }
  • 相关阅读:
    基于消息摆渡节点的DTN路由
    A DTN Congestion Mechanism Based on Distributed Storage
    $(formId).autocomplete is not a function
    [Microsoft][SQLServer 2000 Driver for JDBC]Error establishing socket.
    ajax提交的中文便会变成乱码
    Ajax原理
    Need to specify class name in environment or system property, or as an applet parameter, or in an application resource file: java.naming.factory.init
    JavaScript 无符号位移运算符 >>> 三个大于号 的使用方法
    普通按钮提交
    jsp的一些基本操作
  • 原文地址:https://www.cnblogs.com/Lxiaojiang/p/6644699.html
Copyright © 2011-2022 走看看