题意：给定区间 [n, m]，求其中既不含数字 4、也不含连续两位 "62" 的整数个数（输入 0 0 时结束）。
分析:
1、递推写法：dp[i][j] 表示长度为 i（允许前导零）、最高位（第 i 位）为 j，且不含 4 也不含 "62" 的数字串个数。
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <cmath>
#include <iostream>
#include <sstream>
#include <iterator>
#include <algorithm>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <stack>
#include <deque>
#include <queue>
#include <list>
#define lowbit(x) ((x) & (-(x)))
const double eps = 1e-8;
inline int dcmp(double a, double b){
    if(fabs(a - b) < eps) return 0;
    return a > b ? 1 : -1;
}
typedef long long LL;
typedef unsigned long long ULL;
const int INT_INF = 0x3f3f3f3f;
const int INT_M_INF = 0x7f7f7f7f;
const LL LL_INF = 0x3f3f3f3f3f3f3f3f;
const LL LL_M_INF = 0x7f7f7f7f7f7f7f7f;
const int dr[] = {0, 0, -1, 1, -1, -1, 1, 1};
const int dc[] = {-1, 1, 0, 0, -1, 1, -1, 1};
const int MOD = 1e9 + 7;
const double pi = acos(-1.0);
const int MAXN = 10000 + 10;
const int MAXT = 10000 + 10;
using namespace std;

// Most digits a non-negative int can have. main() calls solve(m + 1), which for
// m = 999999 is the 7-digit number 1000000, so a 6-digit table is NOT enough.
const int MAX_DIGITS = 10;

// dp[i][j]: number of i-digit strings (leading zeros allowed) whose highest
// digit (position i) is j and that contain neither the digit 4 nor "62".
int dp[MAX_DIGITS + 1][10];
// tmp[1..cnt]: digits of solve()'s argument, least significant first, plus one
// extra slot for a 0 sentinel above the top digit.
int tmp[MAX_DIGITS + 2];

// Fills dp[0..MAX_DIGITS], building from low digits to high ones.
// Safe to call more than once (the table is cleared first).
void init(){
    memset(dp, 0, sizeof dp);
    dp[0][0] = 1;
    for(int i = 1; i <= MAX_DIGITS; ++i){
        for(int j = 0; j < 10; ++j){
            if(j == 4) continue;                    // digit 4 is never allowed
            for(int k = 0; k < 10; ++k){
                if(j == 6 && k == 2) continue;      // "62" across positions i, i-1
                dp[i][j] += dp[i - 1][k];           // put j on top of a valid (i-1)-string starting with k
            }
        }
    }
}

// Returns how many integers in [0, x - 1] contain neither 4 nor "62".
// x itself is NOT counted. Requires 0 <= x and init() to have been called.
int solve(int x){
    int cnt = 0;
    while(x){
        tmp[++cnt] = x % 10;
        x /= 10;
    }
    tmp[++cnt] = 0;  // sentinel so tmp[i + 1] is defined for the top digit
    int ans = 0;
    for(int i = cnt - 1; i >= 1; --i){
        // Keep x's digits above position i, put a smaller digit j at position i;
        // the i-1 lower positions are then free, which dp[i][j] counts.
        for(int j = 0; j < tmp[i]; ++j){
            if(j != 4 && !(tmp[i + 1] == 6 && j == 2)){
                ans += dp[i][j];
            }
        }
        // Once x's own prefix contains 4 or "62", every number sharing that
        // prefix is invalid, so stop enumerating.
        if(tmp[i] == 4 || (tmp[i + 1] == 6 && tmp[i] == 2)) break;
    }
    return ans;
}

// Reads pairs n m until "0 0" and prints the count for [n, m], one per line.
int main(){
    int n, m;
    init();
    while(scanf("%d%d", &n, &m) == 2){
        if(!n && !m) return 0;
        // [n, m] = [0, m] \ [0, n - 1]; solve(x) covers [0, x - 1].
        printf("%d\n", solve(m + 1) - solve(n));
    }
    return 0;
}
2、记忆化搜索（DFS）写法：
数位 DP 的关键在于：把不受上界限制（limit 为假）时的重复状态记录下来，供记忆化搜索直接复用。
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cctype>
#include <cmath>
#include <iostream>
#include <sstream>
#include <iterator>
#include <algorithm>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <stack>
#include <deque>
#include <queue>
#include <list>
#define lowbit(x) ((x) & (-(x)))
const double eps = 1e-8;
inline int dcmp(double a, double b){
    if(fabs(a - b) < eps) return 0;
    return a > b ? 1 : -1;
}
typedef long long LL;
typedef unsigned long long ULL;
const int INT_INF = 0x3f3f3f3f;
const int INT_M_INF = 0x7f7f7f7f;
const LL LL_INF = 0x3f3f3f3f3f3f3f3f;
const LL LL_M_INF = 0x7f7f7f7f7f7f7f7f;
const int dr[] = {0, 0, -1, 1, -1, -1, 1, 1};
const int dc[] = {-1, 1, 0, 0, -1, 1, -1, 1};
const int MOD = 1e9 + 7;
const double pi = acos(-1.0);
const int MAXN = 10000 + 10;
const int MAXT = 10000 + 10;
using namespace std;

// Most digits a non-negative int can have; digit[] is 1-indexed, so arrays
// need MAX_DIGITS + 1 slots (the old size 10 overflowed on 10-digit input).
const int MAX_DIGITS = 10;

// dp[len][state]: number of ways to fill the len remaining low digits with no
// upper bound, given state = "the digit just placed was 6"; -1 = not computed.
// digit[1..cnt]: digits of solve()'s argument, least significant first.
int dp[MAX_DIGITS + 1][2], digit[MAX_DIGITS + 1];

// Counts valid completions of the len lowest positions.
// state: previous (higher) digit was 6, so 2 is forbidden next.
// limit: every higher digit equals the bound's, so this digit may not exceed digit[len].
int dfs(int len, bool state, bool limit){
    if(!len) return 1;
    // Only unrestricted states are shared between calls, so only those are memoized.
    if(!limit && dp[len][state] != -1) return dp[len][state];
    int ans = 0, up = limit ? digit[len] : 9;
    for(int i = 0; i <= up; ++i){
        if(i == 4 || (state && i == 2)) continue;
        ans += dfs(len - 1, i == 6, limit && i == up);
    }
    if(!limit) dp[len][state] = ans;
    return ans;
}

// Returns how many integers in [0, x] (x included) contain neither 4 nor "62".
// Requires 0 <= x and dp to have been filled with -1.
int solve(int x){
    int cnt = 0;
    while(x){
        digit[++cnt] = x % 10;
        x /= 10;
    }
    return dfs(cnt, false, true);
}

// Reads pairs n m until "0 0" and prints the count for [n, m], one per line.
int main(){
    int n, m;
    memset(dp, -1, sizeof dp);
    while(scanf("%d%d", &n, &m) == 2){
        if(!n && !m) return 0;
        printf("%d\n", solve(m) - solve(n - 1));
    }
    return 0;
}