首先赞一下题目, 好题
题意:
Marjar University has decided to upgrade the infrastructure of school intranet by using fiber-optic technology. There are N buildings in the school. Each building will be installed with one router. These routers are connected by optical cables in such a way that there is exactly one path between any two routers.
Each router should be initialized with an operating frequency Fi before it starts to work. Due to the limitations of hardware and environment, the operating frequency should be an integer number within [Li, Ri]. In order to reduce the signal noise, the operating frequency of any two adjacent routers should be co-prime.
Edward is the headmaster of Marjar University. He is very interested in the number of different ways to initialize the operating frequency. Please write a program to help him! To make the report simple and neat, you only need to calculate the sum of Fi (modulo 1000000007) in all solutions for each router.
英文自己看。 大概意思就是一棵树上每一个点i能够给一个L[i] ~ R[i]的值, 相邻的两个点的值要互质, 问每一个点的全部情况的值的和, mod 1000000007
大概思路就是枚举一个点, 然后枚举这个点上的值(i)。 然后求出这种情况有多少种(dp[i]), 那么这个点上的答案就是 dp[1] * 1 + dp[2] * 2 + dp[3] * 3 + .........
然后就是树形dp了, 转移方程就是 dp[i][j] =π{(∑{dp[t][k] | gcd(j, k) == 1}) | i 和 t 相邻}
可是这样转移的复杂度是50 * 50000 * 50000, 就算15秒时限也会超时
所以我们能够考虑用不互质的来转移。
设s[i] = ∑dp[i][j]
那么转移方程就是 dp[i][j] =π{(s[i] - ∑{dp[t][k] | gcd(j, k) != 1}) | i 和 t 相邻}
对于dp[i][j],我们能够把 j 质因数分解, 如果 j = p1^e1 * p2^e2 * p3^e3。 50000以内的数最多有6个不同的质因数;
然后我们记录一下 div[i][k] = ∑{ dp[i][j] | j 是 k 的倍数}, 这个能够nlogn的复杂度处理出来;
这样 ∑{dp[t][k] | gcd(j, k) != 1} = div[t][p1] + div[t][p2] + div[t][p3] - div[t][p1 * p2] - div[t][p1 * p3] - div[t][p2 * p3] + div[t][p1 * p2 * p3], 这样能够用容斥原理算了, 复杂度最多为2 ^ 6;
这样dp一次复杂度大概就是 50 * 50000 * (log50000 + 2^6);
要算50个点的话。 还是会超时;
可是这是一颗树。 对每一个点都dp一次的话算了非常多反复的东西, 所以我们不要每次都去所有dp一次, 比如算完i点的了, 要去算j点的, 如果i j相邻。 那么在dp数组中仅仅有i和j的值有变化。 我们就仅仅要再算这两个点的dp转移就够了。
很多其它细节请看代码:
#include <iostream> #include <cstdio> #include <cmath> #include <cstring> #include <ctime> #include <vector> using namespace std; typedef long long LL; const int N = 50009; const LL M = 1000000007; inline void addIt(int &a, int b) { a += b; if(a >= M) a -= M; } inline int sub(int a, int b) { a -= b; if(a < 0) a += M; if(a >= M) a -= M; return a; } struct Num { int p[11]; int allp; }num[N]; struct Data { int dp[N], div[N], all; }data[55], fb[55]; int L[55], R[55]; int ans[55]; int n; vector<int> e[55]; bool vis[55]; int rcn, rcv; void print() { for(int i = 0; i < n; i++) { printf("i = %d ", i); for(int j = 0; j < 6; j++) printf("dp[%d] = %d ", j, data[i].dp[j]); } } int rc(int i, int xs) { int t, r = 0; for(; i < num[rcn].allp; i++) { if(xs * num[rcn].p[i] <= R[rcv]) t = data[rcv].div[xs * num[rcn].p[i]]; else t = 0; //printf("***%d, t = %d ", xs * num[rcn].p[i], t); addIt(r, sub(t, rc(i + 1, xs * num[rcn].p[i]))); } //printf("r = %d ", r); return r; } void dfsTree(int pre, int now) { if(vis[now]) return; vis[now] = true; int i, siz = e[now].size(); for(i = 0; i < siz; i++) dfsTree(now, e[now][i]); if(pre >= 0 && siz <= 1) for(i = L[now]; i <= R[now]; i++) data[now].dp[i] = 1; else for(i = L[now]; i <= R[now]; i++) { data[now].dp[i] = 1; for(int j = 0; j < siz; j++) if(e[now][j] != pre) { rcn = i; rcv = e[now][j]; data[now].dp[i] = (LL)(data[now].dp[i]) * sub(data[e[now][j]].all, rc(0, 1)) % M; } } data[now].all = 0; for(i = 1; i <= R[now]; i++) { data[now].div[i] = 0; for(int j = i; j <= R[now]; j += i) if(j >= L[now]) addIt(data[now].div[i], data[now].dp[j]); if(i >= L[now]) addIt(data[now].all, data[now].dp[i]); } } void dfs(int pre, int now, int deep) { dfsTree(-1, now); //if(now == 1) print(); int i, siz = e[now].size(); ans[now] = 0; for(i = L[now]; i <= R[now]; i++) addIt(ans[now], (LL)data[now].dp[i] * i % M); //fb[deep] = data[now]; //vis[now] = false; for(i = 0; i < siz; i++) if(e[now][i] != pre) { vis[e[now][i]] = false; fb[deep] = data[e[now][i]]; vis[now] = false; dfs(now, e[now][i], deep + 1); data[e[now][i]] = fb[deep]; vis[e[now][i]] = true; } } void init() { int i, j, k; for(i = 0; i < N; i++) num[i].allp = 0; for(i = 2; i < N; i++) if(num[i].allp == 0) for(j = i; j < N; j += i) num[j].p[num[j].allp++] = i; } int main() { //freopen("13F.in", "r", stdin); init(); int T; scanf("%d", &T); while(T--) { scanf("%d", &n); int i, j, k; for(i = 0; i < n; i++) { scanf("%d", &L[i]); } for(i = 0; i < n; i++) { scanf("%d", &R[i]); e[i].clear(); } for(i = 0; i < n - 1; i++) { scanf("%d %d", &j, &k); j--; k--; e[j].push_back(k); e[k].push_back(j); } memset(vis, false, sizeof(vis)); memset(data, 0, sizeof(data)); dfsTree(-1, 0); //print(); dfs(-1, 0, 0); for(i = 0; i < n - 1; i++) printf("%d ", ans[i]); printf("%d ", ans[i]); } return 0; }