zoukankan      html  css  js  c++  java
  • 使用 DL4J 训练中文词向量

    使用 DL4J 训练中文词向量

    1 预处理

    对中文语料的预处理主要包括:分词、去停用词,以及一些根据实际场景制定的过滤规则(下面的代码会丢弃长度小于 2 的词、纯数字和纯英文单词,并且只保留过滤后词数大于 5 的句子)。

    package ai.mole.test;
    
    import org.ansj.domain.Term;
    import org.ansj.splitWord.analysis.ToAnalysis;
    import org.nlpcn.commons.lang.tire.domain.Forest;
    import org.nlpcn.commons.lang.tire.library.Library;
    
    import java.io.*;
    import java.util.ArrayList;
    import java.util.HashSet;
    import java.util.LinkedList;
    import java.util.List;
    import java.util.Locale;
    import java.util.Set;
    import java.util.regex.Pattern;
    
    public class Preprocess {
        private static final Pattern NUMERIC_PATTERN = Pattern.compile("^[.\d]+$");
        private static final Pattern ENGLISH_WORD_PATTERN = Pattern.compile("^[a-z]+$");
    
        public static void main(String[] args) {
            String inPath1 = "D:\MyData\XUGP3\Desktop\测试分词\test1.txt";
            String inPath2 = "D:\MyData\XUGP3\Desktop\测试分词\stop_words.txt";
            String outPath = "D:\MyData\XUGP3\Desktop\测试分词\result1.txt";
            String encoding = "utf-8";
    
            PrintWriter writer = null;
            Forest forest = null;
            try {
                writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outPath), encoding));
                forest = Library.makeForest(Test.class.getResourceAsStream("/library/userLibrary.dic"));
    
                List<String> lineList = IOUtil.readLines(new FileInputStream(inPath1), encoding);
                List<String> stopWordList = IOUtil.readLines(new FileInputStream(inPath2), encoding);
    
                for (String line : lineList) {
                    String[] cols = line.split("\t", -1);
    
                    if (cols.length < 2) {
                        continue;
                    }
    
                    String text = cols[0].trim().toLowerCase() + " " + cols[1].trim().toLowerCase();
    
                    // 分词
                    List<Term> termList = ToAnalysis.parse(text, forest).getTerms();
                    List<String> wordList = new LinkedList<>();
                    for (Term term : termList) {
                        String word = term.getName();
    
                        if (word.length() < 2) {
                            continue;
                        }
    
                        if (stopWordList.contains(word)) {
                            continue;
                        }
    
                        if (isNumeric(word)) {
                            continue;
                        }
    
                        if (isEnglishWord(word)) {
                            continue;
                        }
    
                        wordList.add(word);
                    }
    
                    if (wordList.size() > 5) {
                        String outStr = listToLine(wordList);
                        writer.println(outStr);
                    }
                }
            } catch (FileNotFoundException e) {
                System.out.println("The file does not exist or the path is not correct!!!");
                System.exit(-1);
            } catch (UnsupportedEncodingException e) {
                System.out.println("Does not support the current character set!!!");
            } catch (IOException e) {
                e.printStackTrace();
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                if (writer != null) {
                    writer.close();
                }
            }
        }
    
        private static boolean isNumeric(String text) {
            return NUMERIC_PATTERN.matcher(text).matches();
        }
    
        private static boolean isEnglishWord(String text) {
            return ENGLISH_WORD_PATTERN.matcher(text).matches();
        }
    
        private static String listToLine(List<String> list) {
            StringBuilder sb = new StringBuilder();
            for (int i=0; i<list.size(); i++) {
                sb.append(list.get(i));
                if (i != list.size()-1) {
                    sb.append(" ");
                }
            }
            return sb.toString();
        }
    }
    

    2 训练

    训练的代码非常简单,可以直接参考 DL4J 官网的 Word2Vec 教程;至于 word2vec 的原理,可以看皮提果的博文。由于预处理输出的每一行都是用空格分隔的词,这里直接使用 DefaultTokenizerFactory 按空白切分即可。

    package ai.mole.test;
    
    import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
    import org.deeplearning4j.models.word2vec.Word2Vec;
    import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
    import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
    import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
    import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import java.io.File;
    import java.io.IOException;
    import java.util.Collection;
    
    /**
     * Trains a word2vec model on a pre-segmented corpus (one sentence per line, tokens separated by
     * spaces), writes the word vectors to a text file and reports the nearest neighbours of a few
     * probe words.
     */
    public class TrainWord2VecModel {
        private static final Logger log = LoggerFactory.getLogger(TrainWord2VecModel.class);

        private static final String DEFAULT_CORPUS_PATH = "/data/analyze/xgp/words.txt";
        private static final String DEFAULT_VECTORS_PATH = "/data/analyze/xgp/word_vectors.txt";

        /** Words whose nearest neighbours are reported after training. */
        private static final String[] PROBE_WORDS = {"比亚迪", "长安"};
        /** Number of neighbours reported per probe word. */
        private static final int NEAREST_COUNT = 10;

        /**
         * Runs training.
         *
         * @param args optional {@code [corpusPath [vectorsPath]]}; missing entries use the defaults
         * @throws IOException if an I/O error occurs while reading the corpus or writing vectors
         */
        public static void main(String[] args) throws IOException {
            String corpusPath = args.length > 0 ? args[0] : DEFAULT_CORPUS_PATH;
            String vectorsPath = args.length > 1 ? args[1] : DEFAULT_VECTORS_PATH;

            log.info("Start Training...");
            long st = System.currentTimeMillis();

            log.info("Load & vectorize sentences...");
            SentenceIterator iter = new BasicLineIterator(new File(corpusPath));
            // Assumes the corpus is the output of the preprocessing step (already segmented and
            // space-joined), so no extra token preprocessor is configured.
            TokenizerFactory t = new DefaultTokenizerFactory();

            log.info("Building model...");
            Word2Vec vec = new Word2Vec.Builder()
                    .minWordFrequency(50)
                    .iterations(1)
                    .epochs(100)
                    .layerSize(500)
                    .seed(42)
                    .windowSize(5)
                    .iterate(iter)
                    .tokenizerFactory(t)
                    .build();

            log.info("Fitting word2vec model...");
            vec.fit();

            log.info("Writing word vectors to text file...");
            WordVectorSerializer.writeWordVectors(vec, vectorsPath);

            log.info("Closest words:");
            for (String probe : PROBE_WORDS) {
                Collection<String> nearest = vec.wordsNearest(probe, NEAREST_COUNT);
                // Echo to stdout as well, one list per line, so results stay visible even when no
                // logging backend is configured.
                System.out.println(nearest);
                log.info("{} words closest to '{}': {}", NEAREST_COUNT, probe, nearest);
            }

            long elapsedMs = System.currentTimeMillis() - st;
            log.info("Training is completed, and the time taken is {} ms.", elapsedMs);
            System.out.println("Training is completed, and the time taken is " + elapsedMs + " ms.");
        }
    }
    

    3 调用

    调用训练好的词向量也非常简单:只需调用 WordVectorSerializer 类的静态方法 readWord2VecModel,并传入训练好的词向量文件路径即可。

    // Load the trained vectors and query the nearest neighbours of two probe words.
    // Backslashes in Windows paths must be escaped in Java string literals.
    Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("D:\\MyData\\XUGP3\\Desktop\\测试分词\\vectors.txt");
    Collection<String> bydWordList = word2Vec.wordsNearest("比亚迪", 10);
    Collection<String> changanWordList = word2Vec.wordsNearest("长安", 10);
    System.out.println(bydWordList);
    System.out.println(changanWordList);
    

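    附录 - Maven 依赖

    其中 ${nd4j.backend}、${nd4j.version}、${dl4j.version}、${scala.binary.version}、${datavec.version}、${hadoop.version} 和 ${arbiter.version} 等属性需要在 pom.xml 的 <properties> 中自行定义。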

    <dependencies>
        <!-- NOTE(review): the preprocessing code imports org.ansj.* and org.nlpcn.* classes, which are
             not declared explicitly below; confirm they arrive transitively or add ansj_seg. -->
        <dependency>
            <groupId>org.apdplat</groupId>
            <artifactId>word</artifactId>
            <version>1.3</version>
        </dependency>
    
        <!-- ND4J backend. You need one in every DL4J project. Normally define artifactId as either "nd4j-native-platform" or "nd4j-cuda-7.5-platform" -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
    
        <!-- Core DL4J functionality -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
    
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nlp</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
    
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
    
        <!-- deeplearning4j-ui is used for visualization: see http://deeplearning4j.org/visualization -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui_${scala.binary.version}</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
    
        <!-- ParallelWrapper & ParallelInference live here -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper_${scala.binary.version}</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
    
        <!-- Next 2: used for MapFileConversion Example. Note you need *both* together -->
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-hadoop</artifactId>
            <version>${datavec.version}</version>
        </dependency>
    
        <dependency>
            <groupId>org.apache.hadoop</groupId>
            <artifactId>hadoop-common</artifactId>
            <version>${hadoop.version}</version>
        </dependency>
    
    
        <!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-deeplearning4j</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
        
        <!-- Use the same Scala suffix as the other Scala artifacts to avoid mixing Scala binaries. -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>arbiter-ui_${scala.binary.version}</artifactId>
            <version>${arbiter.version}</version>
        </dependency>
    
        <!-- datavec-data-codec: used only in video example for loading video data -->
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-codec</artifactId>
            <version>${datavec.version}</version>
        </dependency>
    </dependencies>
    

  • 相关阅读:
    linux软件安装
    shell脚本
    ssh密钥登录及远程执行命令
    shell编程
    vi编辑器
    linux入门
    《玩转Bootstrap(JS插件篇)》笔记
    SharePoint BI
    Apache-ActiveMQ transport XmlMessage
    C#操作AD及Exchange Server总结(二)
  • 原文地址:https://www.cnblogs.com/xugenpeng/p/9144656.html
Copyright © 2011-2022 走看看