zoukankan      html  css  js  c++  java
  • B00014 C++实现的AC自动机

    代码来自:A C++ implementation of the aho corasick pattern search algorithm

    源程序如下:

    /*
    * Copyright (C) 2015 Christopher Gilbert.
    *
    * Permission is hereby granted, free of charge, to any person obtaining a copy
    * of this software and associated documentation files (the "Software"), to deal
    * in the Software without restriction, including without limitation the rights
    * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    * copies of the Software, and to permit persons to whom the Software is
    * furnished to do so, subject to the following conditions:
    *
    * The above copyright notice and this permission notice shall be included in all
    * copies or substantial portions of the Software.
    *
    * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    * SOFTWARE.
    */
    
    #ifndef AHO_CORASICK_HPP
    #define AHO_CORASICK_HPP
    
    #include <algorithm>
    #include <cctype>
    #include <cwctype>
    #include <map>
    #include <memory>
    #include <queue>
    #include <set>
    #include <string>
    #include <vector>
    
    namespace aho_corasick {
    
    	// class interval
    	//
    	// A closed range [start, end] of positions: both bounds are inclusive, so
    	// size() is end - start + 1.
    	class interval {
    		size_t d_start;
    		size_t d_end;

    	public:
    		interval(size_t start, size_t end)
    			: d_start(start)
    			, d_end(end) {}

    		size_t get_start() const { return d_start; }
    		size_t get_end() const { return d_end; }
    		size_t size() const { return d_end - d_start + 1; }

    		// True when the two closed ranges share at least one position.
    		bool overlaps_with(const interval& other) const {
    			return d_start <= other.d_end && d_end >= other.d_start;
    		}

    		// True when point lies inside the closed range.
    		bool overlaps_with(size_t point) const {
    			return d_start <= point && point <= d_end;
    		}

    		// Orders by start, then by end. The end must take part in the
    		// ordering: otherwise two different intervals that begin at the same
    		// position are "equivalent" for ordered containers such as std::set,
    		// which then silently drops one of them even though operator== says
    		// they differ.
    		bool operator <(const interval& other) const {
    			if (d_start != other.d_start) {
    				return d_start < other.d_start;
    			}
    			return d_end < other.d_end;
    		}

    		bool operator !=(const interval& other) const {
    			return !(*this == other);
    		}

    		// Two intervals are equal when both bounds match.
    		bool operator ==(const interval& other) const {
    			return d_start == other.d_start && d_end == other.d_end;
    		}
    	};
    
    	// class interval_tree
    	//
    	// Immutable tree over a collection of intervals, built once in the
    	// constructor. Each node stores a centre point, the intervals that contain
    	// that point, and child nodes for the intervals lying entirely to the left
    	// or entirely to the right of it.
    	//
    	// T must provide get_start(), get_end(), size(), operator== and operator!=.
    	template<typename T>
    	class interval_tree {
    	public:
    		using interval_collection = std::vector<T>;

    	private:
    		// class node
    		class node {
    			enum direction {
    				LEFT, RIGHT
    			};
    			using node_ptr = std::unique_ptr<node>;

    			size_t              d_point;
    			node_ptr            d_left;
    			node_ptr            d_right;
    			interval_collection d_intervals;

    		public:
    			// Splits intervals around their midpoint and recurses into the
    			// non-empty left/right partitions.
    			node(const interval_collection& intervals)
    				: d_point(determine_median(intervals))
    				, d_left(nullptr)
    				, d_right(nullptr)
    				, d_intervals()
    			{
    				interval_collection to_left, to_right;
    				for (const auto& cur : intervals) {
    					if (cur.get_end() < d_point) {
    						to_left.push_back(cur);
    					} else if (cur.get_start() > d_point) {
    						to_right.push_back(cur);
    					} else {
    						d_intervals.push_back(cur);
    					}
    				}
    				if (!to_left.empty()) {
    					d_left.reset(new node(to_left));
    				}
    				if (!to_right.empty()) {
    					d_right.reset(new node(to_right));
    				}
    			}

    			// Midpoint between the smallest start and the largest end of
    			// intervals; 0 for an empty collection. Written as
    			// low + (high - low) / 2 so the sum cannot overflow.
    			static size_t determine_median(const interval_collection& intervals) {
    				if (intervals.empty()) {
    					return 0;
    				}
    				size_t lowest = intervals.front().get_start();
    				size_t highest = intervals.front().get_end();
    				for (const auto& cur : intervals) {
    					lowest = std::min<size_t>(lowest, cur.get_start());
    					highest = std::max<size_t>(highest, cur.get_end());
    				}
    				return lowest + (highest - lowest) / 2;
    			}

    			// Returns every stored interval that overlaps i, excluding those
    			// that compare equal to i itself.
    			interval_collection find_overlaps(const T& i) const {
    				interval_collection overlaps;
    				if (d_point < i.get_start()) {
    					// i lies entirely to the right of the centre point.
    					add_to_overlaps(i, overlaps, find_overlapping_ranges(d_right, i));
    					add_to_overlaps(i, overlaps, check_right_overlaps(i));
    				} else if (d_point > i.get_end()) {
    					// i lies entirely to the left of the centre point.
    					add_to_overlaps(i, overlaps, find_overlapping_ranges(d_left, i));
    					add_to_overlaps(i, overlaps, check_left_overlaps(i));
    				} else {
    					// i contains the centre point, so it overlaps every
    					// interval stored at this node.
    					add_to_overlaps(i, overlaps, d_intervals);
    					add_to_overlaps(i, overlaps, find_overlapping_ranges(d_left, i));
    					add_to_overlaps(i, overlaps, find_overlapping_ranges(d_right, i));
    				}
    				return overlaps;
    			}

    		protected:
    			// Appends new_overlaps to overlaps, skipping entries equal to i.
    			void add_to_overlaps(const T& i, interval_collection& overlaps, const interval_collection& new_overlaps) const {
    				for (const auto& cur : new_overlaps) {
    					if (cur != i) {
    						overlaps.push_back(cur);
    					}
    				}
    			}

    			interval_collection check_left_overlaps(const T& i) const {
    				return check_overlaps(i, LEFT);
    			}

    			interval_collection check_right_overlaps(const T& i) const {
    				return check_overlaps(i, RIGHT);
    			}

    			// Filters this node's own intervals against i. LEFT keeps those
    			// starting at or before i's end; RIGHT keeps those ending at or
    			// after i's start.
    			interval_collection check_overlaps(const T& i, direction d) const {
    				interval_collection overlaps;
    				for (const auto& cur : d_intervals) {
    					switch (d) {
    					case LEFT:
    						if (cur.get_start() <= i.get_end()) {
    							overlaps.push_back(cur);
    						}
    						break;
    					case RIGHT:
    						if (cur.get_end() >= i.get_start()) {
    							overlaps.push_back(cur);
    						}
    						break;
    					}
    				}
    				return overlaps;
    			}

    			// Delegates to child when it exists; an absent child has no overlaps.
    			interval_collection find_overlapping_ranges(const node_ptr& child, const T& i) const {
    				if (child) {
    					return child->find_overlaps(i);
    				}
    				return interval_collection();
    			}
    		};
    		node d_root;

    	public:
    		interval_tree(const interval_collection& intervals)
    			: d_root(intervals) {}

    		// Returns intervals with overlapping entries removed, sorted by start.
    		// Candidates are visited largest first (ties: larger start first);
    		// every interval in the tree that overlaps a surviving candidate is
    		// dropped.
    		interval_collection remove_overlaps(const interval_collection& intervals) {
    			interval_collection result(intervals.begin(), intervals.end());
    			std::sort(result.begin(), result.end(), [](const T& a, const T& b) -> bool {
    				if (a.size() != b.size()) {
    					return a.size() > b.size();
    				}
    				return a.get_start() > b.get_start();
    			});

    			// The removal set is keyed on both bounds. Keying it on
    			// T::operator< alone is not safe: an ordering that looks only at
    			// the start would merge distinct intervals sharing a start, so
    			// some overlapping intervals would never be removed.
    			auto by_bounds = [](const T& a, const T& b) -> bool {
    				if (a.get_start() != b.get_start()) {
    					return a.get_start() < b.get_start();
    				}
    				return a.get_end() < b.get_end();
    			};
    			std::set<T, decltype(by_bounds)> removed(by_bounds);
    			for (const auto& candidate : result) {
    				if (removed.count(candidate) != 0) {
    					continue;
    				}
    				for (const auto& overlap : find_overlaps(candidate)) {
    					removed.insert(overlap);
    				}
    			}

    			// Single pass instead of one find+erase per removed interval.
    			result.erase(
    				std::remove_if(result.begin(), result.end(), [&removed](const T& cur) -> bool {
    					return removed.count(cur) != 0;
    				}),
    				result.end()
    			);
    			std::sort(result.begin(), result.end(), by_bounds);
    			return result;
    		}

    		// Returns every interval in the tree that overlaps i, except those
    		// equal to i.
    		interval_collection find_overlaps(const T& i) {
    			return d_root.find_overlaps(i);
    		}
    	};
    
    	// class emit
    	//
    	// A keyword occurrence: the inclusive [start, end] range it covers plus the
    	// keyword text. A default-constructed emit has both bounds set to
    	// size_t(-1) and reports is_empty().
    	template<typename CharType>
    	class emit: public interval {
    	public:
    		using string_type     = std::basic_string<CharType>;
    		using string_ref_type = std::basic_string<CharType>&;

    	private:
    		string_type d_keyword;

    	public:
    		emit()
    			: interval(static_cast<size_t>(-1), static_cast<size_t>(-1))
    			, d_keyword() {}

    		emit(size_t start, size_t end, string_type keyword)
    			: interval(start, end)
    			, d_keyword(keyword) {}

    		// Returns a copy of the keyword.
    		string_type get_keyword() const { return d_keyword; }

    		// True only when both bounds still hold the "unset" value.
    		bool is_empty() const {
    			const size_t unset = static_cast<size_t>(-1);
    			if (get_start() != unset) {
    				return false;
    			}
    			return get_end() == unset;
    		}
    	};
    
    	// class token
    	//
    	// One piece of tokenised text: either a plain fragment, or a fragment that
    	// matched a keyword together with the emit describing that match.
    	template<typename CharType>
    	class token {
    	public:
    		enum token_type{
    			TYPE_FRAGMENT,
    			TYPE_MATCH,
    		};

    		using string_type     = std::basic_string<CharType>;
    		using string_ref_type = std::basic_string<CharType>&;
    		using emit_type       = emit<CharType>;

    	private:
    		token_type  d_type;
    		string_type d_fragment;
    		emit_type   d_emit;

    	public:
    		// Builds a non-matching token. The fragment is only copied, so it is
    		// taken by const reference; const strings and temporaries are accepted
    		// as well as the mutable lvalues the old signature required.
    		token(const string_type& fragment)
    			: d_type(TYPE_FRAGMENT)
    			, d_fragment(fragment)
    			, d_emit() {}

    		// Builds a matching token carrying the emit e.
    		token(const string_type& fragment, const emit_type& e)
    			: d_type(TYPE_MATCH)
    			, d_fragment(fragment)
    			, d_emit(e) {}

    		bool is_match() const { return (d_type == TYPE_MATCH); }

    		// Returns a copy of the token text.
    		string_type get_fragment() const { return d_fragment; }

    		// Returns a copy of the emit; default-constructed for fragment tokens.
    		emit_type get_emit() const { return d_emit; }
    	};
    
    	// class state
    	//
    	// One node of the keyword trie. A state owns its child states (one per
    	// outgoing character), holds a non-owning failure link, and stores the set
    	// of keywords emitted when the state is reached.
    	//
    	// A state created with depth 0 treats itself as the root: next_state()
    	// returns the state itself for characters that have no transition. States
    	// of any other depth return nullptr in that case.
    	template<typename CharType>
    	class state {
    	public:
    		typedef state<CharType>*                 ptr;
    		typedef std::unique_ptr<state<CharType>> unique_ptr;
    		typedef std::basic_string<CharType>      string_type;
    		typedef std::basic_string<CharType>&     string_ref_type;
    		typedef std::set<string_type>            string_collection;
    		typedef std::vector<ptr>                 state_collection;
    		typedef std::vector<CharType>            transition_collection;

    	private:
    		size_t                         d_depth;
    		ptr                            d_root;
    		std::map<CharType, unique_ptr> d_success;
    		ptr                            d_failure;
    		string_collection              d_emits;

    	public:
    		state(): state(0) {}

    		state(size_t depth)
    			: d_depth(depth)
    			, d_root(depth == 0 ? this : nullptr)
    			, d_success()
    			, d_failure(nullptr)
    			, d_emits() {}

    		// Follows the transition for character. Without one, a depth-0 state
    		// returns itself and any other state returns nullptr.
    		ptr next_state(CharType character) const {
    			return next_state(character, false);
    		}

    		// Follows the transition for character; nullptr when there is none,
    		// even on a depth-0 state.
    		ptr next_state_ignore_root_state(CharType character) const {
    			return next_state(character, true);
    		}

    		// Returns the child reached by character, creating it (one level
    		// deeper) when it does not exist yet.
    		ptr add_state(CharType character) {
    			auto found = d_success.find(character);
    			if (found != d_success.end()) {
    				return found->second.get();
    			}
    			// The new child is owned by a unique_ptr before the map is
    			// touched, so it is released if the insertion throws.
    			return d_success.emplace(character, unique_ptr(new state<CharType>(d_depth + 1)))
    				.first->second.get();
    		}

    		size_t get_depth() const { return d_depth; }

    		// Records keyword as emitted by this state.
    		void add_emit(string_ref_type keyword) {
    			d_emits.insert(keyword);
    		}

    		// Records every keyword in emits as emitted by this state.
    		void add_emit(const string_collection& emits) {
    			d_emits.insert(emits.begin(), emits.end());
    		}

    		// Returns a copy of the emitted keywords.
    		string_collection get_emits() const { return d_emits; }

    		ptr failure() const { return d_failure; }

    		void set_failure(ptr fail_state) { d_failure = fail_state; }

    		// Child states, ordered by their transition character.
    		state_collection get_states() const {
    			state_collection result;
    			result.reserve(d_success.size());
    			for (auto it = d_success.cbegin(); it != d_success.cend(); ++it) {
    				result.push_back(it->second.get());
    			}
    			return result;
    		}

    		// Outgoing transition characters in ascending order.
    		transition_collection get_transitions() const {
    			transition_collection result;
    			result.reserve(d_success.size());
    			for (auto it = d_success.cbegin(); it != d_success.cend(); ++it) {
    				result.push_back(it->first);
    			}
    			return result;
    		}

    	private:
    		ptr next_state(CharType character, bool ignore_root_state) const {
    			auto found = d_success.find(character);
    			if (found != d_success.end()) {
    				return found->second.get();
    			}
    			if (!ignore_root_state && d_root != nullptr) {
    				// d_root is only set on a depth-0 state.
    				return d_root;
    			}
    			return nullptr;
    		}
    	};
    
    	template<typename CharType>
    	class basic_trie {
    	public:
    		using string_type = std::basic_string < CharType > ;
    		using string_ref_type = std::basic_string<CharType>&;
    
    		typedef state<CharType>         state_type;
    		typedef state<CharType>*        state_ptr_type;
    		typedef token<CharType>         token_type;
    		typedef emit<CharType>          emit_type;
    		typedef std::vector<token_type> token_collection;
    		typedef std::vector<emit_type>  emit_collection;
    
    		class config {
    			bool d_allow_overlaps;
    			bool d_only_whole_words;
    			bool d_case_insensitive;
    
    		public:
    			config()
    				: d_allow_overlaps(true)
    				, d_only_whole_words(false)
    				, d_case_insensitive(false) {}
    
    			bool is_allow_overlaps() const { return d_allow_overlaps; }
    			void set_allow_overlaps(bool val) { d_allow_overlaps = val; }
    
    			bool is_only_whole_words() const { return d_only_whole_words; }
    			void set_only_whole_words(bool val) { d_only_whole_words = val; }
    
    			bool is_case_insensitive() const { return d_case_insensitive; }
    			void set_case_insensitive(bool val) { d_case_insensitive = val; }
    		};
    
    	private:
    		std::unique_ptr<state_type> d_root;
    		config                      d_config;
    		bool                        d_constructed_failure_states;
    
    	public:
    		basic_trie(): basic_trie(config()) {}
    
    		basic_trie(const config& c)
    			: d_root(new state_type())
    			, d_config(c)
    			, d_constructed_failure_states(false) {}
    
    		basic_trie& case_insensitive() {
    			d_config.set_case_insensitive(true);
    			return (*this);
    		}
    
    		basic_trie& remove_overlaps() {
    			d_config.set_allow_overlaps(false);
    			return (*this);
    		}
    
    		basic_trie& only_whole_words() {
    			d_config.set_only_whole_words(true);
    			return (*this);
    		}
    
    		void insert(string_type keyword) {
    			if (keyword.empty())
    				return;
    			state_ptr_type cur_state = d_root.get();
    			for (const auto& ch : keyword) {
    				cur_state = cur_state->add_state(ch);
    			}
    			cur_state->add_emit(keyword);
    		}
    
    		template<class InputIterator>
    		void insert(InputIterator first, InputIterator last) {
    			for (InputIterator it = first; first != last; ++it) {
    				insert(*it);
    			}
    		}
    
    		token_collection tokenise(string_type text) {
    			token_collection tokens;
    			auto collected_emits = parse_text(text);
    			size_t last_pos = -1;
    			for (const auto& e : collected_emits) {
    				if (e.get_start() - last_pos > 1) {
    					tokens.push_back(create_fragment(e, text, last_pos));
    				}
    				tokens.push_back(create_match(e, text));
    				last_pos = e.get_end();
    			}
    			if (text.size() - last_pos > 1) {
    				tokens.push_back(create_fragment(typename token_type::emit_type(), text, last_pos));
    			}
    			return token_collection(tokens);
    		}
    
    		emit_collection parse_text(string_type text) {
    			check_construct_failure_states();
    			size_t pos = 0;
    			state_ptr_type cur_state = d_root.get();
    			emit_collection collected_emits;
    			for (auto c : text) {
    				if (d_config.is_case_insensitive()) {
    					c = std::tolower(c);
    				}
    				cur_state = get_state(cur_state, c);
    				store_emits(pos, cur_state, collected_emits);
    				pos++;
    			}
    			if (d_config.is_only_whole_words()) {
    				remove_partial_matches(text, collected_emits);
    			}
    			if (!d_config.is_allow_overlaps()) {
    				interval_tree<emit_type> tree(typename interval_tree<emit_type>::interval_collection(collected_emits.begin(), collected_emits.end()));
    				auto tmp = tree.remove_overlaps(collected_emits);
    				collected_emits.swap(tmp);
    			}
    			return emit_collection(collected_emits);
    		}
    
    	private:
    		token_type create_fragment(const typename token_type::emit_type& e, string_ref_type text, size_t last_pos) const {
    			auto start = last_pos + 1;
    			auto end = (e.is_empty()) ? text.size() : e.get_start();
    			auto len = end - start;
    			typename token_type::string_type str(text.substr(start, len));
    			return token_type(str);
    		}
    
    		token_type create_match(const typename token_type::emit_type& e, string_ref_type text) const {
    			auto start = e.get_start();
    			auto end = e.get_end() + 1;
    			auto len = end - start;
    			typename token_type::string_type str(text.substr(start, len));
    			return token_type(str, e);
    		}
    
    		void remove_partial_matches(string_ref_type search_text, emit_collection& collected_emits) const {
    			size_t size = search_text.size();
    			emit_collection remove_emits;
    			for (const auto& e : collected_emits) {
    				if ((e.get_start() == 0 || !std::isalpha(search_text.at(e.get_start() - 1))) &&
    					(e.get_end() + 1 == size || !std::isalpha(search_text.at(e.get_end() + 1)))
    					) {
    					continue;
    				}
    				remove_emits.push_back(e);
    			}
    			for (auto& e : remove_emits) {
    				collected_emits.erase(
    					std::find(collected_emits.begin(), collected_emits.end(), e)
    					);
    			}
    		}
    
    		state_ptr_type get_state(state_ptr_type cur_state, CharType c) const {
    			state_ptr_type result = cur_state->next_state(c);
    			while (result == nullptr) {
    				cur_state = cur_state->failure();
    				result = cur_state->next_state(c);
    			}
    			return result;
    		}
    
    		void check_construct_failure_states() {
    			if (!d_constructed_failure_states) {
    				construct_failure_states();
    			}
    		}
    
    		void construct_failure_states() {
    			std::queue<state_ptr_type> q;
    			for (auto& depth_one_state : d_root->get_states()) {
    				depth_one_state->set_failure(d_root.get());
    				q.push(depth_one_state);
    			}
    			d_constructed_failure_states = true;
    
    			while (!q.empty()) {
    				auto cur_state = q.front();
    				for (const auto& transition : cur_state->get_transitions()) {
    					state_ptr_type target_state = cur_state->next_state(transition);
    					q.push(target_state);
    
    					state_ptr_type trace_failure_state = cur_state->failure();
    					while (trace_failure_state->next_state(transition) == nullptr) {
    						trace_failure_state = trace_failure_state->failure();
    					}
    					state_ptr_type new_failure_state = trace_failure_state->next_state(transition);
    					target_state->set_failure(new_failure_state);
    					target_state->add_emit(new_failure_state->get_emits());
    				}
    				q.pop();
    			}
    		}
    
    		void store_emits(size_t pos, state_ptr_type cur_state, emit_collection& collected_emits) const {
    			auto emits = cur_state->get_emits();
    			if (!emits.empty()) {
    				for (const auto& str : emits) {
    					auto emit_str = typename emit_type::string_type(str);
    					collected_emits.push_back(emit_type(pos - emit_str.size() + 1, pos, emit_str));
    				}
    			}
    		}
    	};
    
    	// Ready-made tries for narrow (char) and wide (wchar_t) text.
    	using trie  = basic_trie<char>;
    	using wtrie = basic_trie<wchar_t>;
    
    
    } // namespace aho_corasick
    
    #endif // AHO_CORASICK_HPP


  • 相关阅读:
    23. 霍纳法则(多项式求值快速算法)
    22. 欧几里德算法(求最大公约数GCD)
    [poj 2106] Boolean Expressions 递归
    [poj 1185] 炮兵阵地 状压dp 位运算
    [MOOC程序设计与算法二] 递归二
    [poj 3254] Corn Fields 状压dp
    [hdu 1074] Doing Homework 状压dp
    [hdu 1568] Fibonacci数列前4位
    [haut] 1281: 邪能炸弹 dp
    [hdu 2604] Queuing 递推 矩阵快速幂
  • 原文地址:https://www.cnblogs.com/tigerisland/p/7564720.html
Copyright © 2011-2022 走看看