首页 > 代码库 > UVA 10712 - Count the Numbers (数位DP)

UVA 10712 - Count the Numbers (数位DP)

UVA 10712 - Count the Numbers

题目链接

题意:求区间[A,B]数字中,子串包含N的数字有多少个

思路:数位DP,写了个记忆化乱搞搞过了,dp[i][j][2][2][2],分别表示i位的时候,末尾为j的情况,后面3维用来处理小于的情况,已经出现过子串的情况,前导0的情况,然后注意特判一下数字0的情况,因为一开始要分解数字,而0是不能分解的。

代码:

#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;

int n, a, b, v[32], vn, dp[12][100][2][2][2], mod;

void tra(int num) {
    vn = 0;
    while (num) {
         v[vn++] = num % 10;
         num /= 10;
      }
      for (int i = 0; i < vn / 2; i++)
          swap(v[i], v[vn - i - 1]);
}

int dfs(int num, int wei, int flag1, int flag2, int flag3) {
    int &ans = dp[num][wei][flag1][flag2][flag3];
    if (ans != -1) return ans;
    ans = 0;
    if (num == vn) {
        if (flag2) ans = 1;
        return ans;
    }
    if (!flag3) {
        if (num + 1 == vn && n == 0) ans += dfs(num + 1, 0, 1, 1, 1);
        else ans += dfs(num + 1, 0, 1, 0, 0);
        for (int i = 1; i <= (flag1 ? 9 : v[num]); i++) {
              int tmp1 = flag1, tmp2 = flag2;
            if (i < v[num]) tmp1 = 1;
            if (wei * 10 + i == n) tmp2 = 1;
            ans += dfs(num + 1, (wei * 10 + i) % mod, tmp1, tmp2, 1);
          }
     }
    else if (flag2) {
        for (int i = 0; i <= (flag1 ? 9 : v[num]); i++) {
            int tmp1 = flag1;
            if (i < v[num]) tmp1 = 1;
            ans += dfs(num + 1, (wei * 10 + i) % mod, tmp1, flag2, flag3);
          }
     }
     else {
         for (int i = 0; i <= (flag1 ? 9 : v[num]); i++) {
             int tmp1 = flag1, tmp2 = flag2;
            if (i < v[num]) tmp1 = 1;
            if (wei * 10 + i == n) tmp2 = 1;
            ans += dfs(num + 1, (wei * 10 + i) % mod, tmp1, tmp2, flag3);
         }
     }
     return ans;
}

int solve(int num) {
    if (num == 0 && n == 0) return 1;
    if (num <= 0) return 0;
    memset(dp, -1, sizeof(dp));
    tra(num);
    return dfs(0, 0, 0, 0, 0);
}

int main() {
    while (~scanf("%d%d%d", &a, &b, &n) && a != -1) {
        if (n >= 0 && n < 10) mod = 1;
        if (n >= 10 && n < 100) mod = 10;
        if (n >= 100) mod = 100;
        printf("%d\n", solve(b) - solve(a - 1));
     }
    return 0;
}