zoukankan      html  css  js  c++  java
  • 【XSY3320】string AC自动机 哈希 点分治

    题目大意

      给一棵树,每条边上有一个字符,求有多少对 ((x,y)(x<y)),满足 (x)(y) 路径上的边上的字符按顺序组成的字符串为回文串。

      (1leq nleq 50000,1leq x_i,y_ileq n,z_iin{0,1})

    题解

      观察一条经过重心的回文串是长什么样的

      (S) 是一个任意的字符串,(T) 是一个回文串。

      建出根到每个节点对应的串的AC自动机。

      那么 (x) 这边的 (S) 串就是 (x) 对应的AC自动机节点的一个后缀, (T) 串是一个前缀。

      dfs 整棵树的 fail 树,先统计每个点作为 (x) 点的贡献,再把作为 (y) 点的贡献加到数据结构中。

      开 (sqrt n) 个长度为 (sqrt n) 的数组 (c_{1,sqrt n})(c_{i,j}) 表示当前节点有多少个长度 (mod i=j) 的祖先。

      当一个点是 (y) 点的时候,令对应长度的字符串的出现次数 (+1),还要对于 (leq sqrt n) 的所有数 (i),令 (c_{i,lvert S vert mod i}++)

      当一个点是 (x) 点的时候,一个回文串的所有回文前缀可以被表示为 (O(log n)) 个等差数列,公差 (leq sqrt n) 的那部分在 (c) 里面查,剩下的暴力查就好了。

      记一个等差数列的首项为 (a_1),公差为 (d),末项为 (a_n),那么贡献就是 dfs 到深度为 (a_n) 的点时 (c_{d,a_1mod d}) 的值减掉 dfs 到深度为 (a_1-d) 的点时 (c_{d,a_1mod d}) 的值。

      先 dfs 一遍把所有询问的信息插到 vector 中,再 dfs 一遍计算答案。

      求一个串的所有回文前缀可以直接哈希。

      时间复杂度:(f(n)=O(n^frac{3}{2})+O(nlog^2 n)=O(n^frac{3}{2}))

      (T(n)=2T(frac{n}{2})+f(n)=2T(frac{n}{2})+O(n^frac{3}{2})=O(n^frac{3}{2}))

    代码

      把这份代码中的后缀自动机换成 AC自动机,回文自动机换成哈希就好了。

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cstdlib>
    #include<ctime>
    #include<utility>
    #include<functional>
    #include<cmath>
    #include<vector>
    #include<queue>
    #include<assert.h>
    //using namespace std;
    using std::min;
    using std::max;
    using std::swap;
    using std::sort;
    using std::reverse;
    using std::random_shuffle;
    using std::lower_bound;
    using std::upper_bound;
    using std::unique;
    using std::vector;
    using std::queue;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef double db;
    typedef std::pair<int,int> pii;
    typedef std::pair<ll,ll> pll;
    void open(const char *s){
    #ifndef ONLINE_JUDGE
    	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
    #endif
    }
    void open2(const char *s){
    #ifdef DEBUG
    	char str[100];sprintf(str,"%s.in",s);freopen(str,"r",stdin);sprintf(str,"%s.out",s);freopen(str,"w",stdout);
    #endif
    }
    int rd(){int s=0,c,b=0;while(((c=getchar())<'0'||c>'9')&&c!='-');if(c=='-'){c=getchar();b=1;}do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9');return b?-s:s;}
    void put(int x){if(!x){putchar('0');return;}static int c[20];int t=0;while(x){c[++t]=x%10;x/=10;}while(t)putchar(c[t--]+'0');}
    int upmin(int &a,int b){if(b<a){a=b;return 1;}return 0;}
    int upmax(int &a,int b){if(b>a){a=b;return 1;}return 0;}
    const int N=50010;
    vector<pii> g[N];
    int sz[N];
    int totsz,rt,rtsz;
    int b[N];
    int n;
    int f[N];
    ll* ss[N];
    ll ss2[N];
    ll ans=0;
    int _log[N];
    struct info
    {
    	int x;
    	int y;
    	int z;
    	info(int a=0,int b=0,int c=0):x(a),y(b),z(c){}
    };
    int cmp(info a,info b)
    {
    	if(a.x!=b.x)
    		return a.x<b.x;
    	return a.z<b.z;
    }
    void dfs1(int x,int fa)
    {
    	sz[x]=1;
    	for(auto v:g[x])
    		if(v.first!=fa&&!b[v.first])
    		{
    			dfs1(v.first,x);
    			sz[x]+=sz[v.first];
    		}
    }
    void dfs2(int x,int fa)
    {
    	int mx=totsz-sz[x];
    	for(auto v:g[x])
    		if(v.first!=fa&&!b[v.first])
    		{
    			dfs2(v.first,x);
    			mx=max(mx,sz[v.first]);
    		}
    	if(mx<rtsz)
    	{
    		rtsz=mx;
    		rt=x;
    	}
    }
    void dfs3(int x,int fa)
    {
    	f[x]=fa;
    	for(auto v:g[x])
    		if(v.first!=fa&&!b[v.first])
    			dfs3(v.first,x);
    }
    int tot;
    int str[N];
    namespace sam
    {
    	int next[2*N][2];
    	int fail[2*N];
    	int len[2*N];
    	int last,cnt;
    	int b[2*N];
    	int a[2*N][2];
    	int s[2*N]; 
    	void init()
    	{
    		while(cnt)
    		{
    			next[cnt][0]=next[cnt][1]=0;
    			a[cnt][0]=a[cnt][1]=0;
    			b[cnt]=0;
    			s[cnt]=0;
    			cnt--;
    		}
    		cnt=1;
    		last=1;
    	}
    	int insert(int p,int c)
    	{
    		if(next[p][c])
    		{
    			last=next[p][c];
    			s[last]++;
    			return last;
    		}
    //		int p=last;
    		int np=++cnt;
    		len[np]=len[p]+1;
    		s[np]=1;
    		for(;p&&!next[p][c];p=fail[p])
    			next[p][c]=np;
    		if(!p)
    			fail[np]=1;
    		else
    		{
    			int q=next[p][c];
    			if(len[q]==len[p]+1)
    				fail[np]=q;
    			else
    			{
    				int nq=++cnt;
    				len[nq]=len[p]+1;
    				memcpy(next[nq],next[q],sizeof next[q]);
    				fail[nq]=fail[q];
    				fail[q]=fail[np]=nq;
    				for(;p&&next[p][c]==q;p=fail[p])
    					next[p][c]=nq;
    			}
    		}
    		return last=np;
    	}
    }
    namespace pam
    {
    	int next[N][2];
    	int trans[N][2];
    	int fail[N];
    	int len[N];
    	int diff[N];
    	int link[N];
    	int top[N];
    	int last;
    	int cnt;
    	void init()
    	{
    		while(cnt>=0)
    		{
    			next[cnt][0]=next[cnt][1]=0;
    			trans[cnt][0]=trans[cnt][1]=0;
    			cnt--;
    		}
    		cnt=1;
    		str[0]=-1;
    		fail[0]=1;
    		fail[1]=0;
    		len[0]=0;
    		len[1]=-1;
    		last=0;
    		link[0]=0;
    		diff[0]=1;
    		diff[1]=0;
    		top[0]=0;
    		top[1]=1;
    		trans[0][0]=trans[0][1]=trans[1][0]=trans[1][1]=1;
    	}
    	int find(int x,int c)
    	{
    		return str[tot-len[x]-1]==c?x:trans[x][c];
    	}
    	void insert(int c)
    	{
    		str[++tot]=c;
    		last=find(last,c);
    		int now=last;
    		if(!next[now][c])
    		{
    			int cur=++cnt;
    			len[cur]=len[now]+2;
    			last=find(fail[last],c);
    			fail[cur]=next[last][c];
    			diff[cur]=len[cur]-len[fail[cur]];
    			if(diff[cur]==diff[fail[cur]])
    			{
    				link[cur]=link[fail[cur]];
    				top[cur]=top[fail[cur]];
    			}
    			else
    			{
    				link[cur]=fail[cur];
    				top[cur]=cur;
    			}
    			if(!link[cur])
    				link[cur]=cur;
    			memcpy(trans[cur],trans[fail[cur]],sizeof trans[cur]);
    			trans[cur][str[tot-len[fail[cur]]]]=fail[cur];
    			next[now][c]=cur;
    		}
    		last=next[now][c];
    	}
    }
    namespace trie
    {
    	int a[N][2];
    	int s[N];
    	int cnt;
    	void clear()
    	{
    		while(cnt)
    		{
    			a[cnt][0]=a[cnt][1]=0;
    			s[cnt]=0;
    			cnt--;
    		}
    		cnt=1;
    	}
    }
    ll s,s2;
    int pos[N];
    int pos2[N];
    int pos3[N];
    int pos4[N];
    int q[N];
    int len[N],id[N],top;
    int head,tail;
    vector<int> e[2*N];
    int sq;
    vector<info> h[2*N];
    int orzzjt,orzzjt2;
    void bfs(int x)
    {
    	sam::init();
    //	sam::s[1]=1;
    	pos[x]=1;
    	head=1;
    	tail=0;
    	q[++tail]=x;
    	trie::clear();
    	pos4[x]=1;
    	while(tail>=head)
    	{
    		int y=q[head++];
    		s+=trie::s[pos4[y]];
    		trie::s[pos4[y]]++;
    		for(auto v:g[y])
    			if(!b[v.first]&&v.first!=f[y])
    			{
    				pos[v.first]=sam::insert(pos[y],v.second);
    				q[++tail]=v.first;
    				if(trie::a[pos4[y]][v.second])
    					pos4[v.first]=trie::a[pos4[y]][v.second];
    				else
    					pos4[v.first]=trie::a[pos4[y]][v.second]=++trie::cnt;
    			}
    	}
    }
    void dfs(int x,int fa)
    {
    	for(int y=pos[x];y!=1&&!sam::b[y];y=sam::fail[y])
    	{
    		sam::a[sam::fail[y]][str[tot-sam::len[sam::fail[y]]]]=y;
    		sam::b[y]=1;
    	}
    	//这样建出来的后缀树不是完整的,但已经够用了 
    	
    	int now=pam::last;
    	pos2[x]=now;
    	if(pam::len[now]==tot)
    	{
    		if(fa)
    			s2++;
    		pos3[x]=now;
    	}
    	else
    		pos3[x]=pos3[fa];
    	for(auto v:g[x])
    		if(!b[v.first]&&v.first!=fa)
    		{
    			pam::last=now;
    			pam::insert(v.second);
    			dfs(v.first,x);
    			tot--;
    		}
    }
    void dfs4(int x)
    {
    	len[++top]=sam::len[x];
    	id[top]=x;
    	for(auto v:e[x])
    		for(int y=pos3[v];y>1;)
    			if(pam::diff[y]<=sq)
    			{
    				h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y]-pam::diff[y])-len]].push_back(info(sam::len[x]-pam::len[y]-pam::diff[y],pam::diff[y],-1));
    				h[id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]])-len]].push_back(info(sam::len[x]-pam::len[pam::link[y]],pam::diff[y],1));
    				//h.push_back(info(sam::len[x]-pam::len[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[y])-len],1));
    //				h.push_back(info(sam::len[x]-pam::len[pam::link[y]]+pam::diff[y],id[lower_bound(len+1,len+top+1,sam::len[x]-pam::len[pam::link[y]]+pam::diff[y])-len],-1));
    				y=pam::fail[pam::link[y]];
    				orzzjt2+=_log[top];
    			}
    			else
    			{
    				y=pam::fail[y];
    			}
    	if(sam::a[x][0])
    		dfs4(sam::a[x][0]);
    	if(sam::a[x][1])
    		dfs4(sam::a[x][1]);
    	top--;
    }
    void dfs5(int x)
    {
    	for(auto v:h[x])
    		if(v.x>=0&&v.x!=sam::len[x])
    			s+=ss[v.y][v.x%v.y]*v.z;
    	orzzjt+=sq;
    	for(int i=1;i<=sq;i++)
    		ss[i][sam::len[x]%i]+=sam::s[x];
    	ss2[sam::len[x]]+=sam::s[x];
    	
    	
    	for(auto v:h[x])
    		if(v.x>=0&&v.x==sam::len[x])
    			s+=ss[v.y][v.x%v.y]*v.z;
    			
    			
    	for(auto v:e[x])
    		for(int y=pos3[v];y>1;)
    			if(pam::diff[y]<=sq)
    			{
    				y=pam::fail[pam::link[y]];
    			}
    			else
    			{
    				s+=ss2[sam::len[x]-pam::len[y]];
    				y=pam::fail[y];
    			}
    			
    	if(sam::a[x][0])
    		dfs5(sam::a[x][0]);
    	if(sam::a[x][1])
    		dfs5(sam::a[x][1]);
    	
    		
    	for(int i=1;i<=sq;i++)
    		ss[i][sam::len[x]%i]-=sam::s[x];
    	ss2[sam::len[x]]-=sam::s[x];
    }
    ll calc(int x)
    {
    	s=0;
    	s2=0;
    	bfs(x);
    	pam::init();
    	dfs(x,0);
    	for(int i=1;i<=sam::cnt;i++)
    	{
    		e[i].clear();
    		h[i].clear();
    	}
    	for(int i=1;i<=tail;i++)
    		e[pos[q[i]]].push_back(q[i]);
    	dfs4(1);
    //	for(int i=1;i<=sam::cnt;i++)
    //		sort(h[i].begin(),h[i].end());
    	dfs5(1);
    	return s;
    }
    int c[N],c2[N];
    int t;
    vector<pii> g2;
    void solve(int x)
    {
    	dfs1(x,0);
    	totsz=sz[x];
    	rtsz=0x7fffffff;
    	dfs2(x,0);
    	x=rt;
    	dfs3(x,0);
    	int t=0;
    	sq=sqrt(totsz);
    //	sq=0;
    	ans+=calc(x);
    	ans+=s2;
    	for(auto v:g[x])
    		if(!b[v.first])
    		{
    			b[v.first]=1;
    			c[++t]=v.first;
    			c2[t]=v.second;
    		}
    	g2=g[x];
    	g[x].clear();
    	for(int i=1;i<=t;i++)
    	{
    		b[c[i]]=0;
    		g[x].clear();
    		g[x].push_back(pii(c[i],c2[i]));
    		ans-=calc(x);
    		b[c[i]]=1;
    	}
    	g[x]=g2;
    	for(int i=1;i<=t;i++)
    		b[c[i]]=0;
    	b[x]=1;
    	for(auto v:g[x])
    		if(!b[v.first])
    			solve(v.first);
    }
    int main()
    {
    	open("string");
    	scanf("%d",&n);
    	for(int i=1;i<=n;i++)
    		for(int j=1,k=0;j<=n;j<<=1,k++)
    			_log[i]=k;
    	int _sqrt=sqrt(n);
    	for(int i=1;i<=_sqrt;i++)
    	{
    		ss[i]=new ll[i];
    		for(int j=0;j<i;j++)
    			ss[i][j]=0;
    	}
    	int x,y,z;
    	for(int i=1;i<n;i++)
    	{
    		scanf("%d%d%d",&x,&y,&z);
    		g[x].push_back(pii(y,z));
    		g[y].push_back(pii(x,z));
    	}
    	solve(1);
    //	assert(ans%2==0);
    //	ans/=2;
    	printf("%lld
    ",ans);
    //	printf("%d
    ",orzzjt);
    //	printf("%d
    ",orzzjt2);
    	return 0;
    }
    
  • 相关阅读:
    c# 设计模式(一) 工厂模式
    微信开发
    一款非常好用的 Windows 服务开发框架,开源项目Topshelf
    基础语法
    C++环境设置
    c++简介
    使用查询分析器和SQLCMD分别登录远程的SQL2005的1434端口
    ps-如何去水印
    html/css/js-横向滚动条的实现
    java中如何给控件设置颜色
  • 原文地址:https://www.cnblogs.com/ywwyww/p/10241224.html
Copyright © 2011-2022 走看看