zoukankan      html  css  js  c++  java
  • engine.go

    package engine

    import (
        "fmt"
        "github.com/huichen/murmur"
        "github.com/huichen/sego"
        "github.com/huichen/wukong/core"
        "github.com/huichen/wukong/storage"
        "github.com/huichen/wukong/types"
        "github.com/huichen/wukong/utils"
        "log"
        "os"
        "runtime"
        "sort"
        "strconv"
        "sync/atomic"
        "time"
    )

    // NumNanosecondsInAMillisecond converts a millisecond count (such as
    // SearchRequest.Timeout) into nanoseconds.
    const NumNanosecondsInAMillisecond = 1000 * 1000

    // PersistentStorageFilePrefix is the file-name prefix of every persistent
    // storage shard; shard i lives in "<folder>/wukong.<i>".
    const PersistentStorageFilePrefix = "wukong"

    // Engine is the search engine. It segments documents into tokens, indexes
    // them across NumShards indexer/ranker pairs, answers searches by fanning
    // out to every shard, and can optionally persist documents to on-disk
    // storage shards. Init must be called before any other method; the public
    // entry points log.Fatal on an uninitialized engine.
    type Engine struct {
        // Counters tracking how many requests were issued and how many were
        // applied; they are updated with sync/atomic, and FlushIndex compares
        // them to decide when all queued work has been applied. Removal and
        // force-update completions are counted once per shard.
        // NOTE(review): keep these uint64 fields first in the struct. The 64-bit
        // sync/atomic functions require 64-bit alignment on 32-bit platforms,
        // and Go only guarantees that for the first word of an allocated struct.
        // numTokenIndexAdded is not read in this file; its exact meaning is
        // presumably "tokens added to the index" — confirm in the workers.
        numDocumentsIndexed      uint64
        numDocumentsRemoved      uint64
        numDocumentsForceUpdated uint64
        numIndexingRequests      uint64
        numRemovingRequests      uint64
        numForceUpdatingRequests uint64
        numTokenIndexAdded       uint64
        numDocumentsStored       uint64

        // Options passed to Init (after options.Init() has run on them), and
        // whether Init has been called.
        initOptions types.EngineInitOptions
        initialized bool

        // indexers[i] and rankers[i] serve shard i. segmenter and stopTokens are
        // only loaded when NotUsingSegmenter is false. dbs holds one storage per
        // persistent storage shard.
        indexers   []core.Indexer
        rankers    []core.Ranker
        segmenter  sego.Segmenter
        stopTokens StopTokens
        dbs        []storage.Storage

        // Channels for adding/removing documents. segmenterChannel is buffered
        // with NumSegmenterThreads slots (presumably drained by segmenterWorker);
        // the slices hold one channel per shard, each with a matching worker
        // goroutine started by Init.
        segmenterChannel         chan segmenterRequest
        indexerAddDocChannels    []chan indexerAddDocumentRequest
        indexerRemoveDocChannels []chan indexerRemoveDocRequest
        rankerAddDocChannels     []chan rankerAddDocRequest

        // Per-shard channels used by searches (lookup, rank) and by rankers for
        // document removal. Init starts NumIndexerThreadsPerShard lookup workers
        // and NumRankerThreadsPerShard rank workers per shard.
        indexerLookupChannels   []chan indexerLookupRequest
        rankerRankChannels      []chan rankerRankRequest
        rankerRemoveDocChannels []chan rankerRemoveDocRequest

        // Persistent storage channels (only allocated when UsePersistentStorage):
        // one unbuffered channel per storage shard, fed by IndexDocument; and a
        // channel on which Init waits for one value per storage shard after
        // starting the restore workers (presumably each signals completion).
        persistentStorageIndexDocumentChannels []chan persistentStorageIndexDocumentRequest
        persistentStorageInitChannel           chan bool
    }

    func (engine *Engine) Init(options types.EngineInitOptions) {
        // 将线程数设置为CPU数
        runtime.GOMAXPROCS(runtime.NumCPU())

        // 初始化初始参数
        if engine.initialized {
            log.Fatal("请勿重复初始化引擎")
        }
        options.Init()
        engine.initOptions = options
        engine.initialized = true

        if !options.NotUsingSegmenter {
            // 载入分词器词典
            engine.segmenter.LoadDictionary(options.SegmenterDictionaries)

            // 初始化停用词
            engine.stopTokens.Init(options.StopTokenFile)
        }

        // 初始化索引器和排序器
        for shard := 0; shard < options.NumShards; shard++ {
            engine.indexers = append(engine.indexers, core.Indexer{})
            engine.indexers[shard].Init(*options.IndexerInitOptions)

            engine.rankers = append(engine.rankers, core.Ranker{})
            engine.rankers[shard].Init()
        }

        // 初始化分词器通道
        engine.segmenterChannel = make(
            chan segmenterRequest, options.NumSegmenterThreads)

        // 初始化索引器通道
        engine.indexerAddDocChannels = make(
            []chan indexerAddDocumentRequest, options.NumShards)
        engine.indexerRemoveDocChannels = make(
            []chan indexerRemoveDocRequest, options.NumShards)
        engine.indexerLookupChannels = make(
            []chan indexerLookupRequest, options.NumShards)
        for shard := 0; shard < options.NumShards; shard++ {
            engine.indexerAddDocChannels[shard] = make(
                chan indexerAddDocumentRequest,
                options.IndexerBufferLength)
            engine.indexerRemoveDocChannels[shard] = make(
                chan indexerRemoveDocRequest,
                options.IndexerBufferLength)
            engine.indexerLookupChannels[shard] = make(
                chan indexerLookupRequest,
                options.IndexerBufferLength)
        }

        // 初始化排序器通道
        engine.rankerAddDocChannels = make(
            []chan rankerAddDocRequest, options.NumShards)
        engine.rankerRankChannels = make(
            []chan rankerRankRequest, options.NumShards)
        engine.rankerRemoveDocChannels = make(
            []chan rankerRemoveDocRequest, options.NumShards)
        for shard := 0; shard < options.NumShards; shard++ {
            engine.rankerAddDocChannels[shard] = make(
                chan rankerAddDocRequest,
                options.RankerBufferLength)
            engine.rankerRankChannels[shard] = make(
                chan rankerRankRequest,
                options.RankerBufferLength)
            engine.rankerRemoveDocChannels[shard] = make(
                chan rankerRemoveDocRequest,
                options.RankerBufferLength)
        }

        // 初始化持久化存储通道
        if engine.initOptions.UsePersistentStorage {
            engine.persistentStorageIndexDocumentChannels =
                make([]chan persistentStorageIndexDocumentRequest,
                    engine.initOptions.PersistentStorageShards)
            for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
                engine.persistentStorageIndexDocumentChannels[shard] = make(
                    chan persistentStorageIndexDocumentRequest)
            }
            engine.persistentStorageInitChannel = make(
                chan bool, engine.initOptions.PersistentStorageShards)
        }

        // 启动分词器
        for iThread := 0; iThread < options.NumSegmenterThreads; iThread++ {
            go engine.segmenterWorker()
        }

        // 启动索引器和排序器
        for shard := 0; shard < options.NumShards; shard++ {
            go engine.indexerAddDocumentWorker(shard)
            go engine.indexerRemoveDocWorker(shard)
            go engine.rankerAddDocWorker(shard)
            go engine.rankerRemoveDocWorker(shard)

            for i := 0; i < options.NumIndexerThreadsPerShard; i++ {
                go engine.indexerLookupWorker(shard)
            }
            for i := 0; i < options.NumRankerThreadsPerShard; i++ {
                go engine.rankerRankWorker(shard)
            }
        }

        // 启动持久化存储工作协程
        if engine.initOptions.UsePersistentStorage {
            err := os.MkdirAll(engine.initOptions.PersistentStorageFolder, 0700)
            if err != nil {
                log.Fatal("无法创建目录", engine.initOptions.PersistentStorageFolder)
            }

            // 打开或者创建数据库
            engine.dbs = make([]storage.Storage, engine.initOptions.PersistentStorageShards)
            for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
                dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
                db, err := storage.OpenStorage(dbPath)
                if db == nil || err != nil {
                    log.Fatal("无法打开数据库", dbPath, ": ", err)
                }
                engine.dbs[shard] = db
            }

            // 从数据库中恢复
            for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
                go engine.persistentStorageInitWorker(shard)
            }

            // 等待恢复完成
            for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
                <-engine.persistentStorageInitChannel
            }
            for {
                runtime.Gosched()
                if engine.numIndexingRequests == engine.numDocumentsIndexed {
                    break
                }
            }

            // 关闭并重新打开数据库
            for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
                engine.dbs[shard].Close()
                dbPath := engine.initOptions.PersistentStorageFolder + "/" + PersistentStorageFilePrefix + "." + strconv.Itoa(shard)
                db, err := storage.OpenStorage(dbPath)
                if db == nil || err != nil {
                    log.Fatal("无法打开数据库", dbPath, ": ", err)
                }
                engine.dbs[shard] = db
            }

            for shard := 0; shard < engine.initOptions.PersistentStorageShards; shard++ {
                go engine.persistentStorageIndexDocumentWorker(shard)
            }
        }

        atomic.AddUint64(&engine.numDocumentsStored, engine.numIndexingRequests)
    }

    // 将文档加入索引
    //
    // 输入参数:
    //  docId          标识文档编号,必须唯一,docId == 0 表示非法文档(用于强制刷新索引),[1, +oo) 表示合法文档
    //  data          见DocumentIndexData注释
    //  forceUpdate 是否强制刷新 cache,如果设为 true,则尽快添加到索引,否则等待 cache 满之后一次全量添加
    //
    // 注意:
    //      1. 这个函数是线程安全的,请尽可能并发调用以提高索引速度
    //      2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
    //         如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
    func (engine *Engine) IndexDocument(docId uint64, data types.DocumentIndexData, forceUpdate bool) {
        engine.internalIndexDocument(docId, data, forceUpdate)

        hash := murmur.Murmur3([]byte(fmt.Sprint("%d", docId))) % uint32(engine.initOptions.PersistentStorageShards)
        if engine.initOptions.UsePersistentStorage && docId != 0 {
            engine.persistentStorageIndexDocumentChannels[hash] <- persistentStorageIndexDocumentRequest{docId: docId, data: data}
        }
    }

    func (engine *Engine) internalIndexDocument(
        docId uint64, data types.DocumentIndexData, forceUpdate bool) {
        if !engine.initialized {
            log.Fatal("必须先初始化引擎")
        }

        if docId != 0 {
            atomic.AddUint64(&engine.numIndexingRequests, 1)
        }
        if forceUpdate {
            atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
        }
        hash := murmur.Murmur3([]byte(fmt.Sprint("%d%s", docId, data.Content)))
        engine.segmenterChannel <- segmenterRequest{
            docId: docId, hash: hash, data: data, forceUpdate: forceUpdate}
    }

    // 将文档从索引中删除
    //
    // 输入参数:
    //  docId          标识文档编号,必须唯一,docId == 0 表示非法文档(用于强制刷新索引),[1, +oo) 表示合法文档
    //  forceUpdate 是否强制刷新 cache,如果设为 true,则尽快删除索引,否则等待 cache 满之后一次全量删除
    //
    // 注意:
    //      1. 这个函数是线程安全的,请尽可能并发调用以提高索引速度
    //      2. 这个函数调用是非同步的,也就是说在函数返回时有可能文档还没有加入索引中,因此
    //         如果立刻调用Search可能无法查询到这个文档。强制刷新索引请调用FlushIndex函数。
    func (engine *Engine) RemoveDocument(docId uint64, forceUpdate bool) {
        if !engine.initialized {
            log.Fatal("必须先初始化引擎")
        }

        if docId != 0 {
            atomic.AddUint64(&engine.numRemovingRequests, 1)
        }
        if forceUpdate {
            atomic.AddUint64(&engine.numForceUpdatingRequests, 1)
        }
        for shard := 0; shard < engine.initOptions.NumShards; shard++ {
            engine.indexerRemoveDocChannels[shard] <- indexerRemoveDocRequest{docId: docId, forceUpdate: forceUpdate}
            if docId == 0 {
                continue
            }
            engine.rankerRemoveDocChannels[shard] <- rankerRemoveDocRequest{docId: docId}
        }

        if engine.initOptions.UsePersistentStorage && docId != 0 {
            // 从数据库中删除
            hash := murmur.Murmur3([]byte(fmt.Sprint("%d", docId))) % uint32(engine.initOptions.PersistentStorageShards)
            go engine.persistentStorageRemoveDocumentWorker(docId, hash)
        }
    }

    // 查找满足搜索条件的文档,此函数线程安全
    func (engine *Engine) Search(request types.SearchRequest) (output types.SearchResponse) {
        if !engine.initialized {
            log.Fatal("必须先初始化引擎")
        }

        var rankOptions types.RankOptions
        if request.RankOptions == nil {
            rankOptions = *engine.initOptions.DefaultRankOptions
        } else {
            rankOptions = *request.RankOptions
        }
        if rankOptions.ScoringCriteria == nil {
            rankOptions.ScoringCriteria = engine.initOptions.DefaultRankOptions.ScoringCriteria
        }

        // 收集关键词
        tokens := []string{}
        if request.Text != "" {
            querySegments := engine.segmenter.Segment([]byte(request.Text))
            for _, s := range querySegments {
                token := s.Token().Text()
                if !engine.stopTokens.IsStopToken(token) {
                    tokens = append(tokens, s.Token().Text())
                }
            }
        } else {
            for _, t := range request.Tokens {
                tokens = append(tokens, t)
            }
        }

        // 建立排序器返回的通信通道
        rankerReturnChannel := make(
            chan rankerReturnRequest, engine.initOptions.NumShards)

        // 生成查找请求
        lookupRequest := indexerLookupRequest{
            countDocsOnly:       request.CountDocsOnly,
            tokens:              tokens,
            labels:              request.Labels,
            docIds:              request.DocIds,
            options:             rankOptions,
            rankerReturnChannel: rankerReturnChannel,
            orderless:           request.Orderless,
        }

        // 向索引器发送查找请求
        for shard := 0; shard < engine.initOptions.NumShards; shard++ {
            engine.indexerLookupChannels[shard] <- lookupRequest
        }

        // 从通信通道读取排序器的输出
        numDocs := 0
        rankOutput := types.ScoredDocuments{}
        timeout := request.Timeout
        isTimeout := false
        if timeout <= 0 {
            // 不设置超时
            for shard := 0; shard < engine.initOptions.NumShards; shard++ {
                rankerOutput := <-rankerReturnChannel
                if !request.CountDocsOnly {
                    for _, doc := range rankerOutput.docs {
                        rankOutput = append(rankOutput, doc)
                    }
                }
                numDocs += rankerOutput.numDocs
            }
        } else {
            // 设置超时
            deadline := time.Now().Add(time.Nanosecond * time.Duration(NumNanosecondsInAMillisecond*request.Timeout))
            for shard := 0; shard < engine.initOptions.NumShards; shard++ {
                select {
                case rankerOutput := <-rankerReturnChannel:
                    if !request.CountDocsOnly {
                        for _, doc := range rankerOutput.docs {
                            rankOutput = append(rankOutput, doc)
                        }
                    }
                    numDocs += rankerOutput.numDocs
                case <-time.After(deadline.Sub(time.Now())):
                    isTimeout = true
                    break
                }
            }
        }

        // 再排序
        if !request.CountDocsOnly && !request.Orderless {
            if rankOptions.ReverseOrder {
                sort.Sort(sort.Reverse(rankOutput))
            } else {
                sort.Sort(rankOutput)
            }
        }

        // 准备输出
        output.Tokens = tokens
        // 仅当CountDocsOnly为false时才充填output.Docs
        if !request.CountDocsOnly {
            if request.Orderless {
                // 无序状态无需对Offset截断
                output.Docs = rankOutput
            } else {
                var start, end int
                if rankOptions.MaxOutputs == 0 {
                    start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
                    end = len(rankOutput)
                } else {
                    start = utils.MinInt(rankOptions.OutputOffset, len(rankOutput))
                    end = utils.MinInt(start+rankOptions.MaxOutputs, len(rankOutput))
                }
                output.Docs = rankOutput[start:end]
            }
        }
        output.NumDocs = numDocs
        output.Timeout = isTimeout
        return
    }

    // 阻塞等待直到所有索引添加完毕
    func (engine *Engine) FlushIndex() {
        for {
            runtime.Gosched()
            if engine.numIndexingRequests == engine.numDocumentsIndexed &&
                engine.numRemovingRequests*uint64(engine.initOptions.NumShards) == engine.numDocumentsRemoved &&
                (!engine.initOptions.UsePersistentStorage || engine.numIndexingRequests == engine.numDocumentsStored) {
                // 保证 CHANNEL  REQUESTS 全部被执行完
                break
            }
        }
        // 强制更新,保证其为最后的请求
        engine.IndexDocument(0, types.DocumentIndexData{}, true)
        for {
            runtime.Gosched()
            if engine.numForceUpdatingRequests*uint64(engine.initOptions.NumShards) == engine.numDocumentsForceUpdated {
                return
            }
        }
    }

    // 关闭引擎
    func (engine *Engine) Close() {
        engine.FlushIndex()
        if engine.initOptions.UsePersistentStorage {
            for _, db := range engine.dbs {
                db.Close()
            }
        }
    }

    // 从文本hash得到要分配到的shard
    func (engine *Engine) getShard(hash uint32) int {
        return int(hash - hash/uint32(engine.initOptions.NumShards)*uint32(engine.initOptions.NumShards))
    }

  • 相关阅读:
    bzoj3473 字符串
    洛谷P4070 生成魔咒
    洛谷P3975 弦论
    AT1219 歴史の研究
    课上讲的几个新的技巧
    索引与视图(转载)
    oracle中的分支与循环语句
    Oracle to_date()函数的用法《转载》
    自定义函数的存储过程的区别
    UNION 和 UNION ALL 操作符
  • 原文地址:https://www.cnblogs.com/zhangboyu/p/7461669.html
Copyright © 2011-2022 走看看