用类似于数位dp的方式, 去求每个数字的贡献。。 好像我写得巨麻烦。
其实转化一下之后, 有很好写的方法。
#include<bits/stdc++.h> #define LL long long #define LD long double #define ull unsigned long long #define fi first #define se second #define mk make_pair #define PLL pair<LL, LL> #define PLI pair<LL, int> #define PII pair<int, int> #define SZ(x) ((int)x.size()) #define ALL(x) (x).begin(), (x).end() #define fio ios::sync_with_stdio(false); cin.tie(0); using namespace std; const int N = 700 + 7; const int inf = 0x3f3f3f3f; const LL INF = 0x3f3f3f3f3f3f3f3f; const int mod = 1e9 + 7; const double eps = 1e-8; const double PI = acos(-1); template<class T, class S> inline void add(T& a, S b) {a += b; if(a >= mod) a -= mod;} template<class T, class S> inline void sub(T& a, S b) {a -= b; if(a < 0) a += mod;} template<class T, class S> inline bool chkmax(T& a, S b) {return a < b ? a = b, true : false;} template<class T, class S> inline bool chkmin(T& a, S b) {return a > b ? a = b, true : false;} int n, cnt, Pow[N]; int v[N]; char s[N]; int dp[N]; int f[N][N][2]; int g[N][N][2]; int sum[N][N][2]; int pre[N][N][2]; int d, w; int getRet1(int p, int ban, bool limit) { if(p == -1) return 1; if(!limit && ~dp[p]) return dp[p]; int ret = 0; int up = limit ? v[p] : 9; for(int i = 0; i <= up; i++) { if(i == ban) continue; add(ret, getRet1(p - 1, ban, limit && (i == up))); } if(!limit) dp[p] = ret; return ret; } int getRet2(int p, int big, int big2, bool have, bool limit) { if(p == -1) { return (have && big > w) || (have && big2 <= w); } if(!limit) { int need = max(0, w - big + 1); int need2 = min(p + 1, w - big2); int ret = 0; if(have) ret = (sum[p][need][0] + sum[p][need][1]) % mod; else ret = sum[p][need][1]; int gg = ret; if(need2 >= 0) { if(have) add(ret, (pre[p][need2][0] + pre[p][need2][1]) % mod); else add(ret, pre[p][need2][1]); } return ret; } int up = limit ? v[p] : 9; int ret = 0; for(int i = 0; i <= up; i++) { add(ret, getRet2(p - 1, big + (i > d), big2 + (i >= d), have || (i == d), limit && (i == up))); } return ret; } int solve(int x) { int ret = 0; memset(f, 0, sizeof(f)); memset(g, 0, sizeof(g)); memset(dp, -1, sizeof(dp)); memset(sum, 0, sizeof(sum)); f[0][1][0] = 9 - x; f[0][0][1] = 1; f[0][0][0] = x; for(int i = 1; i < n; i++) { for(int j = i + 1; j >= 0; j--) { add(f[i][j][0], 1LL * f[i - 1][j][0] * x % mod); if(j) add(f[i][j][0], 1LL * f[i - 1][j - 1][0] * (9 - x) % mod); add(f[i][j][1], 1LL * f[i - 1][j][1] * (x + 1) % mod); if(j) add(f[i][j][1], 1LL * f[i - 1][j - 1][1] * (9 - x) % mod); add(f[i][j][1], f[i - 1][j][0]); } } g[0][1][0] = 9 - x; g[0][1][1] = 1; g[0][0][0] = x; for(int i = 1; i < n; i++) { for(int j = i + 1; j >= 0; j--) { add(g[i][j][0], 1LL * g[i - 1][j][0] * x % mod); if(j) add(g[i][j][0], 1LL * g[i - 1][j - 1][0] * (9 - x) % mod); add(g[i][j][1], 1LL * g[i - 1][j][1] * x % mod); if(j) add(g[i][j][1], 1LL * g[i - 1][j - 1][1] * (10 - x) % mod); if(j) add(g[i][j][1], 1LL * g[i - 1][j - 1][0]); } } for(int i = 0; i < n; i++) { for(int j = i + 1; j >= 0; j--) { sum[i][j][0] = (f[i][j][0] + sum[i][j + 1][0]) % mod; sum[i][j][1] = (f[i][j][1] + sum[i][j + 1][1]) % mod; } } for(int i = 0; i < n; i++) { for(int j = 0; j <= i + 1; j++) { pre[i][j][0] = g[i][j][0]; if(j) add(pre[i][j][0], pre[i][j - 1][0]); pre[i][j][1] = g[i][j][1]; if(j) add(pre[i][j][1], pre[i][j - 1][1]); } } int ncnt = (cnt - getRet1(n - 1, x, 1) + mod) % mod; for(int i = 0; i < n; i++) { int tmp = ncnt; d = x, w = i; sub(tmp, getRet2(n - 1, 0, 0, 0, 1)); add(ret, 1LL * Pow[i] * tmp % mod * x % mod); } return ret; } int main() { for(int i = Pow[0] = 1; i < N; i++) Pow[i] = 1LL * Pow[i - 1] * 10 % mod; scanf("%s", s); n = strlen(s); reverse(s, s + n); for(int i = n - 1; i >= 0; i--) { v[i] = s[i] - '0'; cnt = 1LL * cnt * 10 % mod; add(cnt, v[i]); } add(cnt, 1); int ans = 0; for(int i = 1; i <= 9; i++) add(ans, solve(i)); printf("%d ", ans); return 0; } /* */