code:
// Product of path products over "balanced" tree paths, via centroid
// decomposition.
//
// Input (stdin, or ./input.in when that file exists): n, then n-1 lines
// "u v value color" describing an undirected tree on vertices 1..n. Edges of
// color 0 and color 1 are counted separately (called red / black below).
// Output: the product, modulo 1e9+7, of the edge-value product of every
// unordered vertex pair {A, B}, A != B, whose path satisfies
// red <= 2*black and black <= 2*red. With no such pair the answer is 1.
// NOTE(review): this appears to be Codeforces 833D ("Red-Black Cobweb"),
// judging only by the input format -- confirm against the actual task.
#include <algorithm>
#include <cstdio>
#include <tuple>
#include <vector>

namespace {

constexpr long long kMod = 1000000007LL;

// Fenwick tree of counts over positions 1..max_pos.
struct Fenwick {
    std::vector<int> tree;

    void reset(int max_pos) { tree.assign(max_pos + 1, 0); }

    // Requires pos >= 1 (pos == 0 would never advance).
    void add(int pos, int delta) {
        for (int i = pos; i < static_cast<int>(tree.size()); i += i & -i) tree[i] += delta;
    }

    // Sum of counts at positions 1..pos (0 when pos <= 0).
    int prefix(int pos) const {
        int total = 0;
        for (int i = pos; i > 0; i -= i & -i) total += tree[i];
        return total;
    }
};

// base^exp mod kMod for exp >= 0.
long long power_mod(long long base, long long exp) {
    long long result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

struct Edge {
    int to;
    int value;
    int color;
};

// One endpoint record for the 2D dominance sweep.
struct Event {
    int x;      // primary sweep coordinate
    int y;      // secondary coordinate (stored in the Fenwick tree, offset)
    int op;     // 0 = insert, 1 = query
    int value;  // product of edge values from the current centroid, mod kMod
};

int n = 0;
int offset = 0;  // shifts y into Fenwick range [1, 4n+1]
std::vector<std::vector<Edge>> adj;
std::vector<int> subtree_size;
std::vector<char> removed;  // vertices already used as centroids
std::vector<Event> events;
Fenwick counter;
long long ans = 1;

// Fills subtree_size for the live component rooted at u; returns its size.
int compute_size(int u, int parent) {
    subtree_size[u] = 1;
    for (const Edge& e : adj[u]) {
        if (e.to == parent || removed[e.to]) continue;
        subtree_size[u] += compute_size(e.to, u);
    }
    return subtree_size[u];
}

// Walks toward any child holding more than half of the component.
int find_centroid(int u, int parent, int total) {
    for (const Edge& e : adj[u]) {
        if (e.to != parent && !removed[e.to] && subtree_size[e.to] * 2 > total) {
            return find_centroid(e.to, u, total);
        }
    }
    return u;
}

// Emits an insert and a query event for every live vertex reachable from u.
// With (r, b) = red/black counts from the centroid, the keys are chosen so
// that insert(A) <= query(B) coordinate-wise exactly when the joined path
// A..B satisfies r <= 2b (first coordinate) and b <= 2r (second coordinate).
void collect(int u, int parent, int red, int black, long long prod) {
    events.push_back({red - 2 * black, black - 2 * red, 0, static_cast<int>(prod)});
    events.push_back({2 * black - red, 2 * red - black, 1, static_cast<int>(prod)});
    for (const Edge& e : adj[u]) {
        if (e.to == parent || removed[e.to]) continue;
        collect(e.to, u, red + (e.color == 0), black + (e.color == 1),
                prod * e.value % kMod);
    }
}

// Sweeps all events of the component reachable from `start` and multiplies
// (or, when `subtract` is set, divides) `ans` by value(B)^cnt(B) for every
// query B, where cnt(B) is the number of inserted A dominated by B.
// Because the condition is symmetric, an unordered pair {A, B} is seen once
// as (A, B) and once as (B, A), contributing value(B) * value(A), i.e. the
// path product, exactly once. Self pairs and same-subtree pairs are removed
// by the subtract pass, which reproduces identical coordinates.
void accumulate(int start, int parent, int red, int black, long long prod, bool subtract) {
    events.clear();
    collect(start, parent, red, black, prod);
    // Inserts precede queries on equal x so the dominance test is inclusive.
    std::sort(events.begin(), events.end(), [](const Event& a, const Event& b) {
        return std::tie(a.x, a.op) < std::tie(b.x, b.op);
    });
    for (const Event& ev : events) {
        if (ev.op == 0) {
            counter.add(ev.y + offset, 1);
            continue;
        }
        const int cnt = counter.prefix(ev.y + offset);
        if (cnt == 0) continue;
        // Fermat: v^(-cnt) == v^(p-1-cnt) for v != 0 mod p.
        // NOTE(review): assumes every edge value is nonzero modulo kMod.
        const long long exponent = subtract ? (kMod - 1 - cnt) : cnt;
        ans = ans * power_mod(ev.value, exponent) % kMod;
    }
    for (const Event& ev : events) {
        if (ev.op == 0) counter.add(ev.y + offset, -1);
    }
}

// Centroid decomposition of the live component containing `start`.
void decompose(int start) {
    const int total = compute_size(start, 0);
    const int centroid = find_centroid(start, 0, total);
    accumulate(centroid, 0, 0, 0, 1, false);
    removed[centroid] = 1;
    for (const Edge& e : adj[centroid]) {
        if (removed[e.to]) continue;
        accumulate(e.to, centroid, e.color == 0, e.color == 1, e.value, true);
    }
    for (const Edge& e : adj[centroid]) {
        if (!removed[e.to]) decompose(e.to);
    }
}

}  // namespace

int main() {
    // Local-testing convenience: read ./input.in only if it exists, so that a
    // missing file does not close stdin.
    if (std::FILE* probe = std::fopen("input.in", "r")) {
        std::fclose(probe);
        if (!std::freopen("input.in", "r", stdin)) return 1;
    }
    if (std::scanf("%d", &n) != 1 || n < 1) return 0;

    adj.assign(n + 1, std::vector<Edge>{});
    subtree_size.assign(n + 1, 0);
    removed.assign(n + 1, 0);
    for (int i = 1; i < n; ++i) {
        int u = 0, v = 0, value = 0, color = 0;
        if (std::scanf("%d%d%d%d", &u, &v, &value, &color) != 4) return 1;
        adj[u].push_back({v, value, color});
        adj[v].push_back({u, value, color});
    }

    // Red/black counts are at most n-1, so y lies in [-2n, 2n].
    offset = 2 * n + 1;
    counter.reset(4 * n + 1);
    events.reserve(2 * static_cast<std::size_t>(n));

    decompose(1);
    std::printf("%lld\n", ans);
    return 0;
}