题意:给出多个串s1…sn,每个串si对应happy值hi,对于某个串t,如果它在串si,sj,sk..等串中出现,那么串t的happy值为hi * hj * hk * …,询问一个m,输出长度<=m的随机串的happy值期望。
先来一种暴力做法,遍历s1…sn的所有子串,对于某个子串t,算出它对ans(len(t))的贡献,不重复地累加所有子串贡献,最后取前缀和即可。
回顾下sam的每个状态(点),它代表了多个串,这多个串的结束位置集相同,且如果按照长度将他们排序,那它们长度递增且每个串都是这些串中长度最大的串的后缀,如bbc,bbcd,bbcdc,这点其实好理解,因为最大长度的那个串在某个位置k结束,那它的所有后缀也都会在k出现,那为什么不是从长度1到长度max呢,因为长度小的串可能在别的地方也出现过,也就是说一个串的出现位置集大小总是小于等于它的后缀的出现位置集大小。
回到这题,再下看广义sam做法,先说建立了广义sam后怎么做,此时每个子串都在sam中,在单个串的sam中,每个点代表的串有相同的结束位置集,那在广义sam里也有相同的结束位置集,且这些串在哪些串中出现也是一样的,所以这些串的happy值是一样的,那一个点的happy值就是累乘出现串集的所有h值。
那如何建立广义sam?对一个串si建立sam时,从初始状态开始转移,当要加入字符c = s[i][j]时,如果发现当前状态st已经有边c指向下一状态v,且v代表的串的最大长度=当前已匹配串长度,那状态v可以代表当前串s[i][0]…s[i][j],否则新建结点(不会证)。
那建完广义sam如何维护happy值?每个串跑一遍,遇到的点都维护上权值,并且要沿着fail边跑回去更新,每个点只更新一次。
接下来只要遍历所有点统计答案就行了,因为一个点代表的多个串的长度是连续的,所以再区间覆盖一下即可。
#include<iostream> #include<cmath> #include<cstring> #include<queue> #include<vector> #include<cstdio> #include<algorithm> #include<map> #include<set> #define rep(i,e) for(int i=0;i<(e);i++) #define rep1(i,e) for(int i=1;i<=(e);i++) #define repx(i,x,e) for(int i=(x);i<=(e);i++) #define pii pair<int,int> #define X first #define Y second #define PB push_back #define MP make_pair #define mset(var,val) memset(var,val,sizeof(var)) #define scd(a) scanf("%d",&a) #define scdd(a,b) scanf("%d%d",&a,&b) #define scddd(a,b,c) scanf("%d%d%d",&a,&b,&c) #define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0) using namespace std; typedef long long ll; template <class T> void test(T a){cout<<a<<endl;} template <class T,class T2> void test(T a,T2 b){cout<<a<<" "<<b<<endl;} template <class T,class T2,class T3> void test(T a,T2 b,T3 c){cout<<a<<" "<<b<<" "<<c<<endl;} const int N = 1e6+10; const int inf = 0x3f3f3f3f; const ll INF = 0x3f3f3f3f3f3f3f3f; const int mod = 1e9+7; int idx; int maxlen[N], minlen[N], trans[N][26], slink[N]; int new_state(int _maxlen, int _minlen, int* _trans, int _slink) { maxlen[idx] = _maxlen; minlen[idx] = _minlen; for(int i = 0; i < 26; i++) { if(_trans == NULL) trans[idx][i] = -1; else trans[idx][i] = _trans[i]; } slink[idx] = _slink; return idx++; } int val[N]; int add_char(char ch, int u) { int c = ch - 'a'; int z = new_state(maxlen[u] + 1, -1, NULL, -1); while(u != -1 && trans[u][c] == -1) { trans[u][c] = z; u = slink[u]; } if(u == -1) { minlen[z] = 1; slink[z] = 0; return z; } int x = trans[u][c]; if(maxlen[u] + 1 == maxlen[x]) { minlen[z] = maxlen[x] + 1; slink[z] = x; return z; } int y = new_state(maxlen[u] + 1, -1, trans[x], slink[x]); minlen[z] = minlen[x] = maxlen[y] + 1; slink[z] = slink[x] = y; while(u != -1 && trans[u][c] == x) { trans[u][c] = y; u = slink[u]; } minlen[y] = maxlen[slink[y]] + 1; return z; } string s[N]; int h[N]; int wlk[N]; ll qpow(ll a,ll b){ ll ret = 1; while(b){ if(b&1)ret=ret*a%mod; b>>=1; a=a*a%mod; } return ret; } ll INV(ll a){ return qpow(a,mod-2); } int ans[N]; int vis[N]; void work(){ int n;cin>>n; int mxlen = 0; rep(i,n){ cin>>s[i]; mxlen = max(mxlen, (int)s[i].size()); } new_state(0,0,NULL,-1); rep(i,n){ cin>>h[i]; int st=0; rep(j,s[i].size()){ int nx = trans[st][s[i][j]-'a']; if(nx == -1 || maxlen[nx] != j+1){//建广义sam,两种情况 st=add_char(s[i][j], st); }else{ st = nx; } } } rep(i,n){ int st=0; rep(j,s[i].size()){ st = trans[st][s[i][j]-'a']; int u = st; while(u!=-1){//遍历所有和该串有关的点 if(vis[u]==i+1) break; // 不memset的vis vis[u]=i+1; if(!val[u]) val[u]=h[i]; else val[u]=1ll*val[u]*h[i]%mod; u=slink[u]; } } } rep1(i,idx-1){ (wlk[minlen[i]] += val[i])%=mod; (wlk[maxlen[i]+1] -= val[i])%=mod;//区间覆盖 } int now = 0; int sum = 0; int div = 0; int pw = 1; rep1(i,mxlen){ pw=pw*26ll%mod; (div+=pw)%=mod; (now += wlk[i])%=mod; (sum+=now)%=mod; ans[i] = 1ll*sum*INV(div)%mod; } int m;cin>>m; rep(i,m){ int x;cin>>x; if(x>mxlen) x = mxlen;//记得特判 test((ans[x]+mod)%mod); } } int main() { #ifdef local freopen("in.txt","r",stdin); #endif // local IOS; work(); }