给定一个偶数长度 \(n\) 和字符集（0..9 中的一些数字），问有多少个长度为 \(n\) 的串，满足前 \(\frac{n}{2}\) 位的数位和与后 \(\frac{n}{2}\) 位的数位和相等
\[
f\left( i,j \right) \text{表示}i\text{个数的和是}j\text{的方案数}
\\
\text{答案是}\sum_y{f\left( \frac{n}{2},y \right)}^2\text{（数位和为}y\text{时，前半部分和后半部分各有}f\left( \frac{n}{2},y \right)\text{种取法，两者独立，相乘后再对}y\text{求和）}
\\
\text{所以只需要考虑怎么求}f
\\
f\left( x+1,y \right) =\sum_{i=0}^9{f\left( x,y-i \right)}\text{（先不考虑字符集限制；}y-i<0\text{时该项为}0\text{）}
\\
\text{可以把上面的式子补成卷积形式}
\\
\text{令}H=f\left( x+1 \right) \text{，}F=f\left( x \right) \text{，}H\left( y \right) \text{表示的就是}f\left( x+1,y \right) \text{，}F\text{同理}
\\
H\left( k \right) =\sum_{i=0}^{\min \left( k,9 \right)}{G\left( i \right) F\left( k-i \right)}
\\
\text{即}f\left( x+1 \right) \left( k \right) =\sum_{i=0}^{\min \left( k,9 \right)}{G\left( i \right) f\left( x \right) \left( k-i \right)}
\\
\text{比较容易发现没有字符集限制的情况下}G\left( i \right) =1\ \left( 0\le i\le 9 \right)
\\
\text{有限制的时候}G\left( x \right) =\sum_{i=0}^9{a_ix^i}\text{，}a_i\text{表示}i\text{是否在允许的字符集中，}x\text{没有实际意义（只是形式变量）}
\\
\text{由}f\left( 0 \right) =1\text{（空串和为}0\text{）逐步相乘可得}f\left( \frac{n}{2} \right) =G\left( x \right) ^{\frac{n}{2}}
\\
\text{所以要做的就是求出}G\left( x \right) ^{\frac{n}{2}}\text{，然后统计每一项系数的平方和}
\]
贴个Tutorial里的代码
#include<bits/stdc++.h>
using namespace std;
// NTT configuration shared by the whole program.
constexpr int LOGN = 21;        // number of precomputed table levels (w/iw/rv have LOGN entries)
constexpr int N = 1 << LOGN;    // capacity of the global coefficient arrays
constexpr int MOD = 998244353;  // NTT-friendly prime: MOD - 1 = 119 * 2^23
constexpr int g = 3;            // primitive root modulo MOD
// Counting loop: i runs over [0, n).
#define forn(i, n) for (int i = 0; i < static_cast<int>(n); ++i)
// Returns (a * b) % MOD. The product is formed in 64-bit arithmetic, so inputs in
// [0, MOD) cannot overflow. Negative inputs yield C++'s truncated (possibly negative) remainder.
inline int mul(int a, int b)
{
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}
// Brings a into [0, MOD) by repeatedly subtracting or adding MOD.
// Intended for values just outside the range, e.g. the sums/differences in fft(),
// where at most one correction step is needed.
inline int norm(int a)
{
    for (; a >= MOD; a -= MOD) {
    }
    for (; a < 0; a += MOD) {
    }
    return a;
}
// Computes a^k modulo MOD by binary exponentiation (square-and-multiply).
// Returns 1 when k <= 0.
inline int binPow(int a, int k)
{
    int result = 1;
    for (int base = a; k > 0; k >>= 1)
    {
        if (k & 1)
            result = mul(result, base);
        base = mul(base, base);
    }
    return result;
}
// Modular inverse via Fermat's little theorem: since MOD is prime, a^(MOD-2) == a^(-1).
// For a == 0 (mod MOD), which has no inverse, the result is 0.
inline int inv(int a)
{
    const int fermatExponent = MOD - 2;
    return binPow(a, fermatExponent);
}
vector<int> w[LOGN];
vector<int> iw[LOGN];
vector<int> rv[LOGN];
void precalc()
{
int wb = binPow(g, (MOD - 1) / (1 << LOGN));
for(int st = 0; st < LOGN; st++)
{
w[st].assign(1 << st, 1);
iw[st].assign(1 << st, 1);
int bw = binPow(wb, 1 << (LOGN - st - 1));
int ibw = inv(bw);
int cw = 1;
int icw = 1;
for(int k = 0; k < (1 << st); k++)
{
w[st][k] = cw;
iw[st][k] = icw;
cw = mul(cw, bw);
icw = mul(icw, ibw);
}
rv[st].assign(1 << st, 0);
if(st == 0)
{
rv[st][0] = 0;
continue;
}
int h = (1 << (st - 1));
for(int k = 0; k < (1 << st); k++)
rv[st][k] = (rv[st - 1][k & (h - 1)] << 1) | (k >= h);
}
}
// In-place iterative radix-2 number-theoretic transform of a[0..n) modulo MOD.
//   n       - transform length; must equal 1 << ln, with ln < LOGN so rv[ln] exists
//   inverse - false: forward transform; true: inverse transform including the 1/n scaling
// Requires precalc(). Entries are assumed to lie in [0, MOD); this keeps each
// butterfly sum within int range and needs at most one norm() correction.
inline void fft(int a[N], int n, int ln, bool inverse)
{
    // Permute into bit-reversed order so the butterflies can work in place.
    for (int i = 0; i < n; i++)
    {
        const int j = rv[ln][i];
        if (j > i)
            swap(a[i], a[j]);
    }

    // Stage st merges pairs of length-len halves into blocks of length 2*len.
    for (int st = 0, len = 1; len < n; st++, len <<= 1)
    {
        const vector<int>& tw = inverse ? iw[st] : w[st];
        for (int block = 0; block < n; block += 2 * len)
        {
            for (int j = 0; j < len; j++)
            {
                int& lo = a[block + j];
                int& hi = a[block + j + len];
                const int u = lo;
                const int v = mul(hi, tw[j]);
                lo = norm(u + v);
                hi = norm(u - v);
            }
        }
    }

    if (inverse)
    {
        const int scale = inv(n);
        for (int i = 0; i < n; i++)
            a[i] = mul(a[i], scale);
    }
}
// Scratch buffers used only by multiply().
int aa[N], bb[N], cc[N];
// Multiplies polynomial a (sza coefficients) by polynomial b (szb coefficients) modulo MOD.
// The product is written to c[0..szc), and szc is set to the padded transform length:
// the smallest power of two >= sza + szb. Entries from index sza + szb - 1 on are zero.
// c may alias a or b, because the inputs are copied into scratch buffers first.
// Requires precalc() and a padded length of at most 2^(LOGN-1).
inline void multiply(int a[N], int sza, int b[N], int szb, int c[N], int &szc)
{
    int ln = 0;
    for (; (1 << ln) < sza + szb; ++ln) {
    }
    const int n = 1 << ln;

    // Zero-pad both operands up to the transform length.
    for (int i = 0; i < n; i++)
    {
        aa[i] = i < sza ? a[i] : 0;
        bb[i] = i < szb ? b[i] : 0;
    }

    fft(aa, n, ln, false);
    fft(bb, n, ln, false);
    // Pointwise product in the transformed domain == convolution of coefficients.
    for (int i = 0; i < n; i++)
        cc[i] = mul(aa[i], bb[i]);
    fft(cc, n, ln, true);

    szc = n;
    for (int i = 0; i < n; i++)
        c[i] = cc[i];
}
// Global coefficient arrays (zero-initialized, static storage).
// The former `vector<int> T[N];` was removed: it was never used, yet it constructed
// and destroyed 2^21 std::vector objects (roughly 48 MB on typical 64-bit ABIs).
int a[N];  // main(): digit indicator polynomial G, later G^(n/2) in coefficient form
int b[N];  // unused by main(); presumably spare buffers for multiply() -- verify before removing
int c[N];
#define sz(a) (int(a.size()))
// Reads n (even string length) and k allowed digits. Prints, modulo MOD, the number of
// length-n strings over those digits whose first-half digit sum equals their
// second-half digit sum.
// Method: G(x) = sum of x^d over allowed digits d. The coefficient of x^y in G^(n/2)
// counts halves with digit sum y, so the answer is the sum of squared coefficients.
int main()
{
    precalc();
    int n, k;
    if (scanf("%d %d", &n, &k) != 2)
        return 1;  // malformed input: n/k would otherwise be read uninitialized
    for (int i = 0; i < k; i++)
    {
        int x;
        if (scanf("%d", &x) != 1)
            return 1;
        a[x] = 1;  // x is expected to be a digit 0..9, per the problem statement
    }
    // A half has n/2 digits, each <= 9, so the largest sum is 9n/2 < 5n + 1.
    // A transform of length >= 5n + 1 therefore holds G^(n/2) without cyclic wrap-around.
    int nn = 1, ln = 0;
    const int nw = (n * 5) + 1;
    while (nn < nw)
    {
        nn *= 2;
        ln++;
    }
    fft(a, nn, ln, false);
    // Raising each transformed value to n/2 equals taking G^(n/2) in coefficient space.
    forn(i, nn)
        a[i] = binPow(a[i], n / 2);
    fft(a, nn, ln, true);
    int ans = 0;
    forn(i, nn)
        ans = norm(ans + mul(a[i], a[i]));
    printf("%d\n", ans);
    return 0;
}