逆元链接:https://www.cnblogs.com/zzqc/p/7192436.html
经典的树分治题
#pragma comment("linker,"/STACK:102400000,102400000) #include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define MAXN 100005 #define mod 1000003 typedef long long ll struct E{ int v,d,next; E(){} E(int vv,int nn):v(vv),next(nn){} }e[MAXN<<1]; bool vis[MAXN]; int size,head[MAXN],ans[2];//边和答案 int flag[mod],F[mod],id[MAXN];//哈希 ll ni[mod];//1-mod之间的逆元表 int sum[MAXN],mi,cr;//cr是子树的路径数量 ll val[MAXN],path[MAXN];//点权路径 void init(){ size=0; memset(vis,0,sizeof vis); memset(ans,-1,sizeof ans); memset(head,-1,sizeof head); memset(flag,0,sizeof flag); } void add(int u,int v){ e[size]=E(v,head[u]); head[u]=size++; } //这个函数找路径 void dfs(int u,ll k){ int i,v; sum[u]=1; vis[u]=true; id[cr]=u; path[cr++]=k*val[u]%mod; ll tmp=path[cr-1]; for(i=head[u];i!=-1;i=e[i].next){ v=e[i].v; if(vis[v]) continue; dfs(v,tm); sum[u]+=sum[v]; } vis[u]=false; } ll k; int n,ca; void getans(int a,int b){ if(a>b) swap(a,b); if(ans[0]==-1 || ans[0]>a) ans[0]=a,ans[1]=b; else if(ans[0]==a && ans[1]>b) ans[1]=b; } //找重心 void getroot(int u){ int i,v,mx=0; sum[u]=1; vis[u]=true; for(i=head[u];i!=-1;i=e[i].next){ v=e[i].v; if(vis[v]) continue; getroot(v); sum[u]+=sum[v]; mx=max(mx,sum[v]); } mx=max(mx,sum[0]-sum[u]); if(mx<mi) mi=mx,root=u; vis[u]=false; } void cal(int u,int cnt){ if(cnt==1) return; int i,v,j; mi=n; sum[0]=cnt;//树规模为cnt getroot(u); vis[root]=true; for(i=head[root];i~=-1;i=e[i].next){ v=e[i].v; if(vis[v]) continue; cr=0; dfs(v,1); for(j=0;j<cr;j++){ if(path[j]*val[root]%mod==k)//直接从根出发 getans(root,id[j]); ll tmp=k*ni[path[j]*val[root]%mod]%mod;//k*x逆元=y if(flag[tmp]!=ca) continue;//没有在当前树中找到路径 getans(F[tmp],id[j]);//F[tmp]是值为tmp的path对应的终点编号,id[j]是当前路径的终点编号 } //把当前子树更新到目前哈希表中 for(j=0;j<cr;j++){ int tmp=path[j]; if(flag[tmp]!=ca || F[tmp]>id[j])//如果当前哈希表状态中没有path[j],或者F[tmp]的结点编号没有优化到最小 F[tm]=id[j],flag[tm]=ca; } } ca++;//每次搜完一颗子树,颜色数+1 //这里就是分治 for(i=head[root];i!=-1;i=e[i].next){ if(vis[e[i].v]) continue; cal(e[i].v,sum[e[i].v]); } } //拓展欧几里得打表,求的逆元就是最后的x ll egcd(ll a,ll b,ll &x,ll &y){ ll temp,tempx; if(b==0){ x=1;y=0; return a; } temp=egcd(b,a%b,x,y); tempx=x; x=y; y=tempx-a/b*y; return temp; } int main(){ int u,v,i,j; ll y; for(i=0;i<mod;i++){//拓展欧几里得打表求逆元,也可以递推打表,快速幂打表 egcd(i*1ll,mod*1ll,ni[i],y); ni[i]%=mod,ni[i]=(ni[i]+mod)%mod; } while(scanf("%d%lld",&n,&k)==2){ init(); ca=1;//不能是0 for(i=1;i<=n;i++) scanf("%lld",&val[i]); for(i=1;i<n;i++){ scanf("%d%d",&u,&v); add(u,v),add(v,u); } cal(1,n); if(ans[0]==-1) puts("No solution"); else printf("%d %d ",ans[0],ans[1]); } }