vjudge
题意:给一棵树,每个点上有一个权值,求一条路径使得路径上权值的乘积膜(10^6+3)的结果为(K),输出路径的两个端点(x,y)。如有多解,设(x<y),输出(x)最小的,若仍有多解输出(y)最小的。
sol
点分。
每次考虑所有过重心的路径,开一个桶(T[x])表示到根路径权值乘积(不算根的权值)为(x)的最小节点编号。
注意要先查出所有点到根的权值乘积,全部更新答案,再去更新桶(T)
更新答案的时候用逆元。逆元可以线性预处理出来。
记得要设(T[1]=u),做完这一层之后也要把(T[1])清空
code
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 1e5+5;
const int mod = 1e6+3;
int n,k,inv[mod],val[N],to[N<<1],nxt[N<<1],head[N],cnt;
int sz[N],w[N],root,sum,vis[N],T[mod],dep[N],tmp[N],top,ans1,ans2;
void link(int u,int v){to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;}
void getroot(int u,int f)
{
sz[u]=1;w[u]=0;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (v==f||vis[v]) continue;
getroot(v,u);
sz[u]+=sz[v];w[u]=max(w[u],sz[v]);
}
w[u]=max(w[u],sum-sz[u]);
if (w[u]<w[root]) root=u;
}
void getdeep(int u,int f,int sta)
{
dep[u]=sta;tmp[++top]=u;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (v==f||vis[v]) continue;
getdeep(v,u,1ll*sta*val[v]%mod);
}
}
void solve(int u)
{
vis[u]=1;T[1]=u;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]) continue;
top=0;getdeep(v,0,val[v]);
for (int i=1;i<=top;++i)
{
int xx=1ll*dep[tmp[i]]*val[u]%mod,yy=1ll*k*inv[xx]%mod;
int x=tmp[i],y=T[1ll*k*inv[1ll*dep[tmp[i]]*val[u]%mod]%mod];
if (!y) continue;
if (x>y) swap(x,y);
if (x<ans1||(x==ans1&&y<ans2)) ans1=x,ans2=y;
}
for (int i=1;i<=top;++i) if (!T[dep[tmp[i]]]||tmp[i]<T[dep[tmp[i]]]) T[dep[tmp[i]]]=tmp[i];
}
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]) continue;
top=0;getdeep(v,0,val[v]);
for (int i=1;i<=top;++i) T[dep[tmp[i]]]=0;
}
T[1]=0;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]) continue;
sum=sz[v];root=0;
getroot(v,0);
solve(root);
}
}
int main()
{
inv[1]=1;
for (int i=2;i<mod;++i) inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod;
while (scanf("%d %d",&n,&k)!=EOF)
{
memset(head,0,sizeof(head));cnt=0;
memset(vis,0,sizeof(vis));ans1=ans2=1e9;
for (int i=1;i<=n;++i) val[i]=gi();
for (int i=1;i<n;++i)
{
int u=gi(),v=gi();
link(u,v);link(v,u);
}
root=0;sum=w[0]=n;
getroot(1,0);
solve(root);
if (ans1==1e9) puts("No solution");
else printf("%d %d
",ans1,ans2);
}
}