问题背景:
假设数据库中存有搜索词条及其对应的搜索频度。当用户输入一串字符(不区分大小写)时，需要输出以该字符串为前缀、搜索频度最高的若干词条。
下面是基于 Trie(前缀树)的算法实现。基于 Trie 实现的好处是查询效率高，并且支持动态更新(可以快速插入新词条或修改已有词条的频度)。
C++ 源代码:
// Top-N prefix search over (term, frequency) pairs using a 256-way trie.
//
// Input format (stdin):
//   * "term frequency" lines: the term is everything before the first digit,
//     with trailing spaces removed; the frequency is parsed from that digit
//     onwards with atoi. Terms therefore cannot contain digits.
//   * The first line containing no digit ends the term list and is discarded
//     (an empty line works as the separator).
//   * Every remaining line is a query prefix. For each one the program echoes
//     "<line> : ", then prints up to 10 matching terms as "term frequency",
//     highest frequency first (ties: byte-wise ascending term), or
//     "not found !!!", followed by an empty line. ASCII letters in the query
//     match either case; leading spaces of the query are ignored.
#include <array>
#include <cctype>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

namespace {

// One child slot per possible byte value.
constexpr std::size_t n_ascii = 256;

// Child-slot index for byte `c`. `char` may be signed, so bytes >= 0x80
// (e.g. UTF-8 encoded text) must go through unsigned char; using them
// directly would index the child array with a negative value.
std::size_t slot(char c) { return static_cast<unsigned char>(c); }

struct trie_node {
    std::array<std::unique_ptr<trie_node>, n_ascii> ptrs{};  // owned children
    int freq = 0;
    bool is_word = false;  // true when an inserted term ends at this node
};

// Byte-wise trie. Owns all nodes through unique_ptr, so it is non-copyable
// and frees everything automatically.
class trie {
public:
    // A node matching a searched prefix, paired with the spelling actually
    // stored in the trie (which may differ in letter case from the query).
    using search_result_type = std::pair<const trie_node*, std::string>;

    trie() : root(new trie_node()) {}

    // Adds `word` with frequency `freq`; re-inserting a word overwrites its
    // frequency.
    void insert(const std::string& word, int freq) {
        trie_node* node = root.get();
        for (char c : word) {
            std::unique_ptr<trie_node>& child = node->ptrs[slot(c)];
            if (!child) child.reset(new trie_node());
            node = child.get();
        }
        node->freq = freq;
        node->is_word = true;
    }

    // Dispatches to the case-sensitive or case-insensitive lookup.
    std::vector<search_result_type> search(const std::string& prefix,
                                           bool case_sensitive = false) const {
        return case_sensitive ? search_case_sensitive(prefix)
                              : search_case_insensitive(prefix);
    }

    // Returns the node spelled exactly by `prefix` (at most one result;
    // empty when no stored term starts with `prefix`).
    std::vector<search_result_type> search_case_sensitive(
            const std::string& prefix) const {
        const trie_node* node = root.get();
        for (char c : prefix) {
            node = node->ptrs[slot(c)].get();
            if (!node) return {};
        }
        return {search_result_type(node, prefix)};
    }

    // Returns every node whose path equals `prefix` up to ASCII letter case.
    std::vector<search_result_type> search_case_insensitive(
            const std::string& prefix) const {
        std::vector<search_result_type> results;
        std::string spelled(prefix);
        aux_search(root.get(), prefix, 0, spelled, results);
        return results;
    }

private:
    // Depth-first walk matching word[i..] against the children of `node`,
    // trying both cases for ASCII letters. spelled[0..i) holds the stored
    // spelling of the path walked so far.
    static void aux_search(const trie_node* node, const std::string& word,
                           std::size_t i, std::string& spelled,
                           std::vector<search_result_type>& results) {
        if (!node) return;
        if (i == word.size()) {
            results.emplace_back(node, spelled);
            return;
        }
        const unsigned char c = static_cast<unsigned char>(word[i]);
        spelled[i] = word[i];
        aux_search(node->ptrs[c].get(), word, i + 1, spelled, results);

        // The program never calls setlocale, so the "C" locale applies and
        // only ASCII letters have a case counterpart.
        int other = 0;
        if (std::isupper(c)) {
            other = std::tolower(c);
        } else if (std::islower(c)) {
            other = std::toupper(c);
        }
        if (other != 0 && other != c) {
            spelled[i] = static_cast<char>(other);
            aux_search(node->ptrs[static_cast<unsigned char>(other)].get(),
                       word, i + 1, spelled, results);
        }
    }

    std::unique_ptr<trie_node> root;
};

// Returns `str` with leading ' ' characters removed ("" if it is all spaces).
std::string trim_head(const std::string& str) {
    const std::size_t pos = str.find_first_not_of(' ');
    return pos == std::string::npos ? std::string() : str.substr(pos);
}

// Returns `str` with trailing ' ' characters removed ("" if it is all spaces).
std::string trim_tail(const std::string& str) {
    const std::size_t pos = str.find_last_not_of(' ');
    return pos == std::string::npos ? std::string() : str.substr(0, pos + 1);
}

// Index of the first decimal digit in `s`, or std::string::npos if none.
std::size_t find_first_digit(const std::string& s) {
    for (std::size_t i = 0; i < s.size(); ++i) {
        // isdigit is undefined for negative char values; go via unsigned char.
        if (std::isdigit(static_cast<unsigned char>(s[i]))) return i;
    }
    return std::string::npos;
}

using value_type = std::pair<std::string, int>;  // (term, frequency)

// "Ranks better than": higher frequency first, then byte-wise smaller term.
// Used as a priority_queue comparator, it keeps the worst retained entry on
// top, so the queue behaves as a bounded min-heap of the best entries.
struct my_comp {
    bool operator()(const value_type& v1, const value_type& v2) const {
        if (v1.second != v2.second) return v1.second > v2.second;
        return v1.first < v2.first;
    }
};

using top_n_heap =
    std::priority_queue<value_type, std::vector<value_type>, my_comp>;

// Adds `v` to `heap` while it holds fewer than `limit` entries; once full,
// `v` replaces the worst entry only if it ranks better. limit == 0 keeps
// nothing (and never touches top() of an empty heap).
void offer(top_n_heap& heap, std::size_t limit, value_type v) {
    if (limit == 0) return;
    if (heap.size() < limit) {
        heap.push(std::move(v));
    } else if (my_comp{}(v, heap.top())) {
        heap.pop();
        heap.push(std::move(v));
    }
}

// Offers every term stored at or below `node` to `heap`. `path` must hold
// the spelling of `node`; it is extended in place and restored on return.
void traverse(const trie_node* node, std::string& path, top_n_heap& heap,
              std::size_t limit) {
    if (!node) return;
    if (node->is_word) offer(heap, limit, value_type(path, node->freq));
    for (std::size_t i = 0; i < n_ascii; ++i) {
        if (const trie_node* child = node->ptrs[i].get()) {
            path.push_back(static_cast<char>(i));
            traverse(child, path, heap, limit);
            path.pop_back();
        }
    }
}

// Prints up to `n` terms beginning with `prefix` (ASCII case-insensitive),
// best first, one "term frequency" per line; prints "not found !!!" when
// nothing matches or n <= 0.
void list_top_n(const trie& tr, const std::string& prefix, int n) {
    const std::size_t limit = n > 0 ? static_cast<std::size_t>(n) : 0;
    top_n_heap heap;
    for (const trie::search_result_type& result :
         tr.search_case_insensitive(prefix)) {
        std::string path = result.second;
        traverse(result.first, path, heap, limit);
    }

    // The heap yields the worst entry first; reverse to print best first.
    std::vector<value_type> items;
    items.reserve(heap.size());
    while (!heap.empty()) {
        items.push_back(heap.top());
        heap.pop();
    }
    if (items.empty()) {
        std::cout << "not found !!!" << '\n';
        return;
    }
    for (auto it = items.rbegin(); it != items.rend(); ++it) {
        std::cout << it->first << ' ' << it->second << '\n';
    }
}

}  // namespace

int main() {
    trie tr;
    std::string line;

    // Load "term frequency" lines until the first line without a digit,
    // which is consumed as the separator.
    while (std::getline(std::cin, line)) {
        const std::size_t pos = find_first_digit(line);
        if (pos == std::string::npos) break;
        // atoi stops at the first non-digit after the number.
        tr.insert(trim_tail(line.substr(0, pos)),
                  std::atoi(line.c_str() + pos));
    }

    // Answer each remaining line as a query prefix.
    std::string prefix;
    while (std::getline(std::cin, prefix)) {
        std::cout << prefix << " : " << '\n';
        list_top_n(tr, trim_head(prefix), 10);
        std::cout << '\n';
    }
    return 0;
}
测试:
输入(先是若干行“词条 频度”，词条中不能含数字；然后是一个不含数字的分隔行(如空行)，它只用来结束词条输入；此后每行是一个查询前缀):
Baidu 100
Google 100
Google Map 150
Google Play 200
gfsoso 100
google 250
Go 50

G
输出:
G :
google 250
Google Play 200
Google Map 150
Google 100
gfsoso 100
Go 50