令 (T = 2*3*5*7*11*47)
(f[i][j]) 表示确定 (i) 位数,模 (T) 的余数为 (j) 的方案数(类似于数位 (dp) 直接转移就好了)。
然后发现 (T) 可以优化, (T = 3*7*11*47) ,我们只要让最终的数不被这 (4) 个数整除,然后最后一位强制不填 (2,5),只填 (1,3,7)
然后倍增优化
令(pw = 10^frac{i}{2}) , 则有 (f[i][j*pw+k] = sum f[frac{i}{2}][j] * f[frac{i}{2}][k])
本来需要 (n*T) 次转移,现在需要 (T^2logn) 次,但还能优化那个 (T^2)
固定后面的 (f[frac{i}{2}][k]) (可以将其看为 (h[k]) ),对于所有合法的的 (f[frac{i}{2}][j]) 求和得到 (g) 数组,就可以将 (T^2) 的转移用 (FFT) 优化到 (TlogT) 了
(g[j] = sumlimits_{x*pwequiv j (mod T)} f[frac{i}{2}][x])
(f[i][j+k] = sum g[j]*h[k])
时间复杂度 (O(TlogTlogn))
注意事项:
-
快速幂算 (10^{frac{i}{2}}) 时千万记得模的是 (T)
-
(f[i][j]) 应该是卷积结果的第 (j) 项以及第 (j+T) 项的和
-
因为 (frac{i}{2}) 是下取整,如果 (i) 是奇数,用 (f[frac{i}{2}]) 转移出来的是 (f[i-1]) ,所以再暴力 (O(T)) 转移一下得到 (f[i])
代码
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
#define db long double
#define rint register int
using namespace std;
const int mod = 9973, T = 10857, N = 1<<15|10;
const double pi = M_PI;
int n, rev[N], nxt[T+10][5], f[T+10], lim, bit;
ll ff[T+10];
struct Com {
db x, y;
Com operator + (const Com &B) { return (Com){x + B.x, y + B.y}; }
Com operator - (const Com &B) { return (Com){x - B.x, y - B.y}; }
Com operator * (const Com &B) { return (Com){x * B.x - y * B.y, x * B.y + y * B.x}; }
} g[N], h[N];
int Pow(rint a, rint x, rint ans = 1) {
for(;x;x >>= 1, a = 1ll * a * a % T)
if(x&1) ans = 1ll * ans * a % T;
return ans;
}
void FFT(Com *a, rint opt) {
for(rint i = 0;i < lim; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(rint d = 1;d < lim;d <<= 1) {
rint st = d<<1;
Com w1 = (Com){cos(pi/d), sin(pi/d) * opt};
for(rint i = 0;i < lim; i += st) {
Com w = (Com){1, 0};
for(rint j = 0;j < d; ++j, w = w * w1) {
Com x = a[i+j], y = w * a[i+j+d];
a[i+j] = x + y;
a[i+j+d] = x - y;
}
}
}
if(~opt) return;
for(rint i = 0;i < lim; ++i) a[i].x = a[i].x/lim+0.5;
}
void Solve(rint x) {
if(!x) return f[0] = 1, void();
Solve(x/2);
rint pw = Pow(10, x/2);
for(rint i = 0;i < lim; ++i) h[i] = g[i] = (Com){0, 0};
for(rint i = 0;i < T; ++i) h[i].x = f[i], g[1ll*pw*i%T].x += f[i];
FFT(h, 1), FFT(g, 1);
for(rint i = 0;i < lim; ++i) h[i] = h[i] * g[i];
FFT(h, -1);
for(rint i = 0;i < T; ++i) f[i] = ((ll)h[i].x + (ll)h[i+T].x) % mod;
if(x&1) {
memset(ff, 0, sizeof(ff));
for(rint i = 0;i < T; ++i) {
ff[nxt[i][0]] += f[i];
ff[nxt[i][1]] += f[i];
ff[nxt[i][2]] += f[i];
ff[nxt[i][3]] += f[i];
ff[nxt[i][4]] += f[i];
}
for(rint i = 0;i < T; ++i) f[i] = ff[i] % mod;
}
}
int main() {
scanf("%d", &n);
lim = 1, bit = 0;
while(lim <= T*2) lim <<= 1, ++bit; --bit;
for(rint i = 0;i < lim; ++i) rev[i] = (rev[i>>1]>>1) | ((i&1)<<bit);
for(rint i = 0;i < T; ++i) {
nxt[i][0] = (10*i+1)%T;
nxt[i][1] = (10*i+2)%T;
nxt[i][2] = (10*i+3)%T;
nxt[i][3] = (10*i+5)%T;
nxt[i][4] = (10*i+7)%T;
}
Solve(n-1);
memset(ff, 0, sizeof(ff));
for(rint i = 0;i < T; ++i) {
ff[nxt[i][0]] += f[i];
ff[nxt[i][2]] += f[i];
ff[nxt[i][4]] += f[i];
}
rint ret = 0;
for(rint i = 0;i < T; ++i) if(i % 3 && i % 7 && i % 11 && i % 47) ret += ff[i] % mod;
printf("%d
", ret % mod);
return 0;
}