给定两个 ( exttt{AB?}) 字符串 (c,d) 和正整数 (n),求在所有将 ( exttt ?) 替换为 ( exttt{A/B}) 的方案中,满足 (1le |S|,|T|le n),将 (c,d) 的 ( exttt A) 替换为 (S),将 ( exttt B) 替换为 (T) 使得 (c=d) 的 ( exttt{01}) 字符串对 ((S,T)) 的个数之和(mod(10^9+7))。
(|c|,|d|,nle 3cdot 10^5)
手玩一下,发现当 (c= exttt{AB}),(d= exttt{BA}) 时 (S,T) 都有长为 (gcd(|S|,|T|)) 的整周期。
结论:当 (c e d) 时 (S,T) 都有长为 (gcd(|S|,|T|)) 的整周期。
证明:考虑对 (|S|+|T|) 归纳,显然当 (|S|=|T|=1) 时结论成立。否则不妨设 (|S|le |T|),若 (|S|=|T|) 则显然 (S=T),若 (|S|<|T|) 则把 (c,d) 的 lcp 次掉之后,不妨设 (c_1= exttt A,d_1= exttt B),则 (S) 是 (T) 的前缀,设 (T=S+T'),并将 (c,d) 中的 ( exttt B) 替换为 ( exttt{AB}),此时 (|S|+|T|) 更小了,且 (c_2 e d_2) 所以 (c e d),由归纳假设可知 (S,T') 都有长为 (gcd(|S|,|T'|)),则结论显然成立,得证。
所以 (c,d) 中的顺序无关紧要,设 (a) 是 (c) 的 ( exttt{A}) 个数减去 (d) 的 ( exttt A) 个数,(b) 是 (d) 的 ( exttt B) 个数减去 (a) 的 ( exttt B) 个数。
显然有 (acdot|S|=bcdot|T|),配合上结论即为充要条件。
当 (a=b=0) 时
后面的和式只与 (lfloor n/p floor) 有关,整除分块即可 (O(n))。
当然有一种特殊情况:(c=d) 时答案为 ((sum_{i=1}^n2^i)^2)。
此外,只有当 (ab>0) 时才有解,答案为 (sum_{i=1}^r2^i),其中 (r=lfloorfrac{ngcd(a,b)}{max(a,b)} floor)。
然后要考虑有问号的情况:设 (x,y) 表示 (c,d) 的 ( exttt ?) 数量,(f_{a,b}) 表示上述答案的值,则
那么就做完了,时间复杂度 (O(n+|c|+|d|))。
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 3e5+3, M = 26000, mod = 1e9+7;
int ksm(int a, int b){
int r = 1;
for(;b;b >>= 1, a = (LL)a * a % mod)
if(b & 1) r = (LL)r * a % mod;
return r;
}
void qmo(int &x){x += x >> 31 & mod;}
int n, tot, nc, nd, res, ans, a, b, x, y, fac[N<<1], inv[N<<1], pw[N], mu[N], pri[M];
char c[N], d[N];
bool notp[N];
int calc(int n){
int res = 0;
for(int l = 1, r, x;l <= n;l = r+1){
r = n / (x = n/l);
res = (res + ((LL)mu[r]-mu[l-1]+mod)*x%mod*x)%mod;
} return res;
}
int F(int a, int b){
if(!a && !b) return res;
if((LL)a*b <= 0) return 0;
if(a < 0){a = -a; b = -b;}
return pw[n/(max(a,b)/__gcd(a,b))];
}
int main(){
scanf("%s%s%d", c, d, &n);
nc = strlen(c); nd = strlen(d);
for(int i = 0;i < nc;++ i)
switch(c[i]){
case 'A': ++ a; break;
case 'B': -- b; break;
case '?': ++ x;
}
for(int i = 0;i < nd;++ i)
switch(d[i]){
case 'A': -- a; break;
case 'B': ++ b; break;
case '?': ++ y;
}
pw[1] = 2;
for(int i = 2;i <= n;++ i) qmo(pw[i] = (pw[i-1]<<1) - mod);
for(int i = 2;i <= n;++ i) qmo(pw[i] += pw[i-1] - mod);
fac[0] = mu[1] = 1;
for(int i = 1;i <= x+y;++ i) fac[i] = (LL)fac[i-1] * i % mod;
inv[x+y] = ksm(fac[x+y], mod-2);
for(int i = x+y;i;-- i) inv[i-1] = (LL)inv[i] * i % mod;
notp[0] = notp[1] = true;
for(int i = 2;i <= n;++ i){
if(!notp[i]) mu[pri[tot++] = i] = -1;
for(int j = 0;j < tot && i * pri[j] < N;++ j){
notp[i*pri[j]] = true;
if(i % pri[j]) mu[i*pri[j]] = -mu[i];
else break;
}
}
for(int i = 2;i <= n;++ i) qmo(mu[i] += mu[i-1]);
for(int l = 1, r, x;l <= n;l = r+1){
r = n / (x = n/l);
res = (res + ((LL)pw[r]-pw[l-1]+mod)*calc(x)) % mod;
}
for(int i = 0;i <= x+y;++ i)
ans = (ans + (LL)inv[i]*inv[x+y-i]%mod*F(a-y+i,b-x+i))%mod;
ans = (LL)ans * fac[x+y] % mod;
if(nc == nd){
int tmp = 1;
for(int i = 0;i < nc;++ i)
if(c[i] == '?' && d[i] == '?') qmo(tmp = (tmp<<1) - mod);
else if(c[i] != '?' && d[i] != '?' && c[i] != d[i]){tmp = 0; break;}
ans = (ans + ((LL)pw[n]*pw[n]+mod-res)%mod*tmp)%mod;
}
printf("%d
", ans);
}