题意
给定一个长度为n字符串,字符集大小为m(1<=n,m<=1e6),求(igoplus_{c = 1}^{m}left(h(c) cdot 3^c mod (10^9+7) ight))的值。其中h(c)为将c加到字符串末尾产生的新的本质不同的子串数目。
解题思路
比赛的时候没做出来,颁奖的时候听lts和lsx讲了之后发现可以用SAM做,而且板子稍微改改就可以了。
具体就是每次添加一个字符最多新建2个节点,根据SAM的性质,添加c后新建节点对本质不同的子串的数目的贡献就是h(c),用len[i]-len[link[i]]就能算出来。
现场赛的时候我一直没想到添加结点之后要怎么把这个结点再删掉,没有立刻想出来,然后就去做D数位DP了,但是其实很简单,可以把extend函数改一改,对于询问只算答案不修改原来的节点,这样就不需要删除;或者用一个容器记录修改了哪一些结点,询问完了再遍历一遍改回去。这两种方法应该都是(O(1))修改的。
代码实现
和叉姐的标程对拍没有出错,应该是对的吧
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e6+5;
const int mod=1e9+7;
struct Suffix_Automaton{
struct state{
int len,link;
map<int,int>next;
}st[maxn<<1];
int last,tot;
int sz;
void init(){
st[1].len=0;st[0].link=0;
st[1].next.clear();
last=tot=1;;
}
int newnode(){
++tot;
st[tot].len=st[tot].link=0;
st[tot].next.clear();
return tot;
}
void extend(int c){
int p=last;
int cur=newnode();
st[cur].len=st[last].len+1;
last=cur;
while(p && !st[p].next.count(c)){
st[p].next[c]=cur;
p=st[p].link;
}
if(!p)st[cur].link=1;
else{
int q=st[p].next[c];
if(st[p].len+1==st[q].len)st[cur].link=q;
else{
int clone=newnode();
st[clone].len=st[p].len+1;
st[clone].next=st[q].next;
st[clone].link=st[q].link;
st[q].link=st[cur].link=clone;
while(st[p].next[c]==q){
st[p].next[c]=clone;
p=st[p].link;
}
}
}
}
ll solve(int c){
ll res=0;
int p=last;
int cur=newnode();
st[cur].len=st[last].len+1;
//last=cur;
while(p && !st[p].next.count(c)){
//st[p].next[c]=cur;
p=st[p].link;
}
if(!p){
st[cur].link=1;
res=st[cur].len-st[st[cur].link].len;
}
else{
int q=st[p].next[c];
if(st[p].len+1==st[q].len){
st[cur].link=q;
res=st[cur].len-st[st[cur].link].len;
}
else{
int clone=newnode();
st[clone].len=st[p].len+1;
//st[clone].next=st[q].next;
st[clone].link=st[q].link;
//while(p && st[p].next[c]==q){
//st[p].next[c]=clone;
//p=st[p].link;
//}
//st[q].link=clone;
st[cur].link=clone;
//结点q的link变为clone,所以需要把原来的删掉再把新的加进去
res=((res-(st[q].len-st[st[q].link].len))%mod+mod)%mod;
res=(res+st[q].len-st[clone].len)%mod;
res=(res+st[cur].len-st[st[cur].link].len)%mod;
res=(res+st[clone].len-st[st[clone].link].len)%mod;
}
}
tot=sz;
return (res%mod+mod)%mod;
}
}S;
int n,m;
int main()
{
while(~scanf("%d %d",&n,&m)){
S.init();
int x;
for(int i=1;i<=n;i++){
scanf("%d",&x);
S.extend(x);
}
S.sz=S.tot;
ll ans=0,three=1,hc;
for(int i=1;i<=m;i++){
three=three*3%mod;
hc=S.solve(i);
ans^=(hc*three%mod);
}
printf("%lld
",ans);
}
return 0;
}