昨天学了下树分治,今天补这道题,还是太不熟练了,写完之后一直超时。后来查出好多错= =比如v,u写倒了,比如+写成了取最值,比如。。。。爆int。。。查了两个多小时的错。。哭。。。(没想到进首页了
http://hzwer.com/6107.html 大神博客,代码清晰,照着这个改的
逆元预处理之前是没有见过的,学习了。
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <map> #include <vector> typedef long long ll; using namespace std; const int N = 100005; const int MOD = 1000003; inline int read() { int x=0;char ch=getchar(); while(ch<'0'||ch>'9')ch=getchar(); while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x; } struct Edge { int to, next; } edge[N*2]; int head[N]; int edge_cnt; int a[N], sz[N], inv[MOD]; bool used[N]; int root, minsz, size; int ansx, ansy; int tmp[N], id[N], cnt; int mp[MOD]; int n, k; void up(int &x, int y) { if(y>x) x=y; } void query(int v, int x) { int ser = (ll)k*inv[v]%MOD; int y = mp[ser]; if (y == 0 || x == y) return ; if (y < x) swap(x, y); if (x < ansx ||(x == ansx && y < ansy)) ansx=x, ansy=y; } void add_edge(int u, int v) { edge[edge_cnt].to = v; edge[edge_cnt].next = head[u]; head[u] = edge_cnt++; } void get_root(int u, int fa) { sz[u] = 1; int maxn= 0; for (int i = head[u]; ~i; i = edge[i].next) { int v = edge[i].to; if (v == fa || used[v]) continue; get_root(v, u); sz[u] += sz[v]; up(maxn, sz[v]); } up(maxn, size-sz[u]); if (maxn < minsz) minsz=maxn, root=u; } void dfs(int u, int fa, int val) { tmp[cnt] = val; id[cnt++] = u; for (int i = head[u]; ~i; i = edge[i].next) { int v = edge[i].to; if (v == fa || used[v]) continue; dfs(v, u, (ll)val*a[v]%MOD); } } void solve(int u) { used[u] = true; mp[ a[u] ] = u; // 计算经过u的所有乘积为k的点对 // 对每一个子节点处理 防止找的的点对是同一个子树的 for (int i = head[u]; ~i; i = edge[i].next) { int v = edge[i].to; if (used[v]) continue; cnt = 0; dfs(v, u, a[v]); for (int j = 0; j < cnt; ++j) query(tmp[j], id[j]); for (int j = 0; j < cnt; ++j) { tmp[j] = (ll)tmp[j]*a[u]%MOD; int &now = mp[tmp[j]]; if (now == 0 || now > id[j]) mp[tmp[j]]=id[j]; } } // 删除所有记录 因为处理子树内时相互没有影响 mp[a[u]] = 0; for (int i = head[u]; ~i; i = edge[i].next) { int v = edge[i].to; if (used[v]) continue; cnt = 0; dfs(v, u, (ll)a[v]*a[u]%MOD); for (int j = 0; j < cnt; ++j) { mp[tmp[j]] = 0; } } for (int i = head[u]; ~i; i = edge[i].next) { int v = edge[i].to; if (used[v]) continue; size = sz[v]; minsz = n+1; get_root(v, 0); solve(root); } } int main() { inv[1]=1; for(int i=2;i<MOD;i++) { int a=MOD/i,b=MOD%i; inv[i]=((ll)inv[b]*(-a)%MOD+MOD)%MOD; } while (~scanf("%d%d", &n, &k)) { for (int i = 1; i <= n; ++i) a[i] = read(); int u, v; memset(head, -1, sizeof head); edge_cnt = 0; memset(used, 0, sizeof used); for (int i = 1; i < n; ++i) { u = read(); v = read(); add_edge(u, v); add_edge(v, u); } minsz = n+1; size = n; get_root(1, 0); ansx = ansy = MOD; solve(root); if (ansx == MOD) puts("No solution"); else printf("%d %d ", ansx, ansy); } return 0; }