Description
程序员 ZS 有一棵树,它可以表示为 (n) 个顶点的无向连通图,顶点编号从 (0) 到 (n-1),它们之间有 (n-1) 条边。每条边上都有一个非零的数字。
一天,程序员 ZS 无聊,他决定研究一下这棵树的一些特性。他选择了一个十进制正整数 (M),(gcd(M,10)=1)。
对于一对有序的不同的顶点 ((u, v)),他沿着从顶点 (u) 到顶点 (v)的最短路径,按经过顺序写下他在路径上遇到的所有数字(从左往右写),如果得到一个可以被 (M) 整除的十进制整数,那么就认为 ((u,v)) 是有趣的点对。
帮助程序员 ZS 得到有趣的对的数量。
Hint
- (1le nle 10^5)
- (1le mle 10^9,gcd(m, 10) = 1)
- (1le ext{边权} < 10)
Solution
这种树上路径的统计问题基本都是 点分治,而点分治的重点和难点就是如何 统计经过分治中心的满足条件的路径的个数。
这里采用 容斥法:即现分治中心为 (s),当前答案等于整个子树 (s) 的答案减去以 (s) 各个子结点为根的子树的答案。
考虑如何统计。
我们设有一条路径是 (x ightarrow y),分治中心为 (s),路径 (x ightarrow s) 对应的数字为 (pd),(s ightarrow y) 对应 (nd),(s) 到 (y) 的距离为 (l)。
那么只有 (pd imes 10^l + nd equiv 0 pmod m) 成立时满足要求。
变形一下:(pd equiv -nd imes 10^{-l}pmod m)。
于是我们可以这样搞:把所有的 (pd) 用 map
存起来,记录一下个数,用 pair
数组把 ((nd, l)) 记录下来。
导入所有了路径信息后,枚举 pair
数组,查找 map
中的元素配对即可。
预处理一下 (10) 的幂及其逆元的话,时间复杂度 (O(nlog^2 n))。如果用 Hash Table 可以优化到理论 (O(nlog n)),但没什么必要。
Code
/*
* Author : _Wallace_
* Source : https://www.cnblogs.com/-Wallace-/
* Problem : Codeforces 715E Digit Tree
*/
#include <cstdio>
#include <map>
#include <utility>
#include <vector>
using namespace std;
const int N = 1e5 + 5;
namespace Inv {
void extgcd(long long a, long long b, long long& x, long long& y) {
if (!b) x = 1, y = 0;
else extgcd(b, a % b, y, x), y -= a / b * x;
}
inline long long get(long long b, long long p) {
long long x, y;
extgcd(b, p, x, y);
x = (x % p + p) % p;
return x;
}
}
int n, m;
long long p10[N], invp[N];
long long ans;
struct edge { int to, len; };
vector<edge> G[N];
int root;
int maxp[N], size[N];
bool centr[N];
int getSize(int x, int f) {
size[x] = 1;
for (auto y : G[x])
if (!centr[y.to] && y.to != f)
size[x] += getSize(y.to, x);
return size[x];
}
void getCentr(int x, int f, int t) {
maxp[x] = 0;
for (auto y : G[x])
if (!centr[y.to] && y.to != f) {
getCentr(y.to, x, t);
maxp[x] = max(maxp[x], size[y.to]);
}
maxp[x] = max(maxp[x], t - size[x]);
if (maxp[x] < maxp[root]) root = x;
}
vector<pair<long long, int> > dat;
map<long long, int> cnt;
void getData(int x, int f, long long pd, long long nd, int dep) {
if (dep >= 0) cnt[pd]++, dat.push_back(make_pair(nd, dep));
for (auto y : G[x]) {
if(centr[y.to] || y.to == f) continue;
long long tpd = (pd + y.len * p10[dep + 1] % m) % m;
long long tnd = (nd * 10 % m + y.len) % m;
getData(y.to, x, tpd, tnd, dep + 1);
}
}
inline long long count(int x, int d) {
long long ret = 0;
cnt.clear(), dat.clear();
if (d == 0) getData(x, 0, 0, 0, -1);
else getData(x, 0, d % m, d % m, 0);
for (auto p : dat) {
long long t = ((-p.first * invp[p.second + 1] % m) + m) % m;
if (cnt.count(t)) ret += cnt[t];
if (d == 0 && p.first == 0) ++ret;
}
return ret + (d == 0 ? cnt[0] : 0);
}
void solve(int x) {
maxp[root = 0] = N;
getCentr(x, 0, getSize(x, 0));
int s = root; centr[s] = true;
for (auto y : G[s])
if (!centr[y.to])
solve(y.to);
ans += count(s, 0);
for (auto y : G[s])
if (!centr[y.to])
ans -= count(y.to, y.len);
centr[s] = false;
}
signed main() {
scanf("%d%d", &n, &m);
for (register int i = 1; i < n; i++) {
int u, v, l;
scanf("%d%d%d", &u, &v, &l);
++u, ++v;
G[u].push_back(edge{v, l});
G[v].push_back(edge{u, l});
}
p10[0] = 1 % m;
for (register int i = 1; i <= n; i++)
p10[i] = p10[i - 1] * 10 % m;
invp[n] = Inv::get(p10[n], m);
for (register int i = n - 1; i; i--)
invp[i] = invp[i + 1] * 10 % m;
ans = 0, solve(1);
printf("%lld
", ans);
return 0;
}