题意:给定一棵n个点带边权的树,定义每条路径的值为路径上边权的异或和
如果一条路径的值为0,其对答案的贡献为所有包含这条路径的路径条数
求答案膜1e9+7
n<=1e5,0<=边权<=1e18
思路:
做法一:点分治
参考https://dudulu.net/blog/?p=1654
考场上还剩2小时的时候开的这题,乍一看觉得很可做就是个裸的点分,结果发现自己不会算贡献,场上连样例都没调出来
现在也是写了两天发现不会算贡献,参考了dalao的博客才会
每条路径的权值可以用两个端点分别能扩展的点的个数的乘积来算
预处理一下以1为根每个点的父亲和子树大小
在分治过程中,假设有u->……->pre->v这条路径
对于v这个点:
如果pre=f[v]可扩展的个数就是siz[v]
若果pre!=f[v]可扩展的个数就是n-siz[pre]
对于分治中心u和u的每个直接分支v:
如果u=f[v]可扩展的个数就是n-siz[v]
如果u!=f[v]可扩展的个数就是siz[u]
扔到map里记一下前缀和
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef unsigned int uint; 5 typedef unsigned long long ull; 6 typedef pair<int,int> PII; 7 typedef pair<ll,ll> Pll; 8 typedef vector<int> VI; 9 typedef vector<PII> VII; 10 //typedef pair<ll,ll>P; 11 #define N 200010 12 #define M 200010 13 #define fi first 14 #define se second 15 #define MP make_pair 16 #define pb push_back 17 #define pi acos(-1) 18 #define mem(a,b) memset(a,b,sizeof(a)) 19 #define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++) 20 #define per(i,a,b) for(int i=(int)a;i>=(int)b;i--) 21 #define lowbit(x) x&(-x) 22 #define Rand (rand()*(1<<16)+rand()) 23 #define id(x) ((x)<=B?(x):m-n/(x)+1) 24 #define ls p<<1 25 #define rs p<<1|1 26 27 const ll MOD=1e9+7,inv2=(MOD+1)/2; 28 double eps=1e-6; 29 ll INF=1e15; 30 int dx[4]={-1,1,0,0}; 31 int dy[4]={0,0,-1,1}; 32 33 struct node 34 { 35 int x; 36 ll y; 37 }q[N]; 38 39 map<ll,ll>mp1,mp2; 40 map<ll,ll>::iterator it; 41 int head[N],vet[N],nxt[N],c[N],flag[N],son[N], 42 s[N],siz[N],f[N],pre[N],tot,root,sum,n,now; 43 ll len[N],dis[N],ans; 44 45 int read() 46 { 47 int v=0,f=1; 48 char c=getchar(); 49 while(c<48||57<c) {if(c=='-') f=-1; c=getchar();} 50 while(48<=c&&c<=57) v=(v<<3)+v+v+c-48,c=getchar(); 51 return v*f; 52 } 53 54 void add(int a,int b,ll c) 55 { 56 nxt[++tot]=head[a]; 57 vet[tot]=b; 58 len[tot]=c; 59 head[a]=tot; 60 } 61 62 void getroot(int u,int fa) 63 { 64 son[u]=1; c[u]=0; 65 int e=head[u]; 66 while(e) 67 { 68 int v=vet[e]; 69 if(v!=fa&&!flag[v]) 70 { 71 getroot(v,u); 72 son[u]+=son[v]; 73 c[u]=max(c[u],son[v]); 74 } 75 e=nxt[e]; 76 } 77 c[u]=max(c[u],sum-c[u]); 78 if(c[u]<c[root]) root=u; 79 } 80 81 void dfs(int u,int fa) 82 { 83 siz[u]=1; 84 int e=head[u]; 85 while(e) 86 { 87 int v=vet[e]; 88 if(v!=fa) 89 { 90 f[v]=u; 91 dfs(v,u); 92 siz[u]+=siz[v]; 93 } 94 e=nxt[e]; 95 } 96 } 97 98 void getdis(int u,int fa) 99 { 100 if(fa==f[u]) 101 { 102 mp1[dis[u]]=(mp1[dis[u]]+siz[u])%MOD; 103 if(dis[u]==0) ans=(ans+1ll*siz[u]*now%MOD)%MOD; 104 } 105 else 106 { 107 mp1[dis[u]]=(mp1[dis[u]]+n-siz[fa])%MOD; 108 if(dis[u]==0) ans=(ans+1ll*(n-siz[fa])*now%MOD)%MOD; 109 } 110 int e=head[u]; 111 while(e) 112 { 113 int v=vet[e]; 114 if(v!=fa&&!flag[v]) 115 { 116 dis[v]=dis[u]^len[e]; 117 getdis(v,u); 118 } 119 e=nxt[e]; 120 } 121 } 122 123 void calc(int u) 124 { 125 dis[u]=0; 126 int e=head[u]; 127 while(e) 128 { 129 int v=vet[e]; 130 if(flag[v]) 131 { 132 e=nxt[e]; 133 continue; 134 } 135 mp1.clear(); 136 if(f[v]==u) now=n-siz[v]; 137 else now=siz[u]; 138 dis[v]=len[e]; 139 getdis(v,u); 140 it=mp1.begin(); 141 while(it!=mp1.end()) 142 { 143 ll w=it->fi,t=it->se; 144 ans=(ans+t*mp2[w]%MOD)%MOD; 145 it++; 146 } 147 it=mp1.begin(); 148 while(it!=mp1.end()) 149 { 150 ll w=it->fi,t=it->se; 151 mp2[w]=(mp2[w]+t)%MOD; 152 it++; 153 } 154 e=nxt[e]; 155 } 156 mp2.clear(); 157 } 158 159 void solve(int u) 160 { 161 flag[u]=1; 162 calc(u); 163 int e=head[u]; 164 while(e) 165 { 166 int v=vet[e]; 167 if(!flag[v]) 168 { 169 sum=son[v]; root=0; 170 getroot(v,0); 171 solve(root); 172 } 173 e=nxt[e]; 174 } 175 } 176 177 int main() 178 { 179 //freopen("1.in","r",stdin); 180 //freopen("1.out","w",stdout); 181 n=read(); 182 rep(i,1,n) head[i]=flag[i]=0; 183 tot=0; 184 rep(i,2,n) 185 { 186 int x=read(); 187 ll y; 188 scanf("%lld",&y); 189 add(i,x,y); 190 add(x,i,y); 191 } 192 dfs(1,0); 193 sum=n; c[0]=n; ans=0; root=0; 194 getroot(1,0); 195 solve(root); 196 printf("%lld ",ans); 197 return 0; 198 }
做法二:算贡献
预处理出每个点的子树大小和每个点到根的xor和
对于(x,y)这条权值为0的链分两种情况,均在dfs到dfs序靠后的那个点时计算贡献:
1.y在x的子树中,贡献在y时候被算到,贡献为size[y]*(n-size[z]),其中z为x的一个儿子,且y在z的子树中
2.x,y是从他们的lca下来的两条链拼起来的,贡献为size[x]*size[y]
可以看出对于第一种情况,可以用map记一下根到每个点(n-size[z])之和,在切换子树时需要减去(n-size[z])消除贡献
对于第二种情况,记size[u]之和,不需要减去
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 typedef unsigned int uint; 5 typedef unsigned long long ull; 6 typedef pair<int,int> PII; 7 typedef pair<ll,ll> Pll; 8 typedef vector<int> VI; 9 typedef vector<PII> VII; 10 //typedef pair<ll,ll>P; 11 #define N 200010 12 #define M 200010 13 #define fi first 14 #define se second 15 #define MP make_pair 16 #define pb push_back 17 #define pi acos(-1) 18 #define mem(a,b) memset(a,b,sizeof(a)) 19 #define rep(i,a,b) for(int i=(int)a;i<=(int)b;i++) 20 #define per(i,a,b) for(int i=(int)a;i>=(int)b;i--) 21 #define lowbit(x) x&(-x) 22 #define Rand (rand()*(1<<16)+rand()) 23 #define id(x) ((x)<=B?(x):m-n/(x)+1) 24 #define ls p<<1 25 #define rs p<<1|1 26 27 const ll MOD=1e9+7,inv2=(MOD+1)/2; 28 double eps=1e-6; 29 ll INF=1e15; 30 int dx[4]={-1,1,0,0}; 31 int dy[4]={0,0,-1,1}; 32 33 map<ll,ll>mp; 34 int head[N],vet[N],nxt[N],s[N],tot,n; 35 ll a[N],dis[N],ans; 36 37 int read() 38 { 39 int v=0,f=1; 40 char c=getchar(); 41 while(c<48||57<c) {if(c=='-') f=-1; c=getchar();} 42 while(48<=c&&c<=57) v=(v<<3)+v+v+c-48,c=getchar(); 43 return v*f; 44 } 45 46 void add(int a,int b) 47 { 48 nxt[++tot]=head[a]; 49 vet[tot]=b; 50 head[a]=tot; 51 } 52 53 void dfs(int u,int fa) 54 { 55 int e=head[u]; 56 s[u]=1; 57 while(e) 58 { 59 int v=vet[e]; 60 if(v!=fa) 61 { 62 dis[v]=dis[u]^a[v]; 63 dfs(v,u); 64 s[u]+=s[v]; 65 } 66 e=nxt[e]; 67 } 68 } 69 70 void solve(int u,int fa) 71 { 72 ans=(ans+1ll*s[u]*mp[dis[u]]%MOD)%MOD; 73 int e=head[u]; 74 while(e) 75 { 76 int v=vet[e]; 77 if(v!=fa) 78 { 79 mp[dis[u]]=(mp[dis[u]]+n-s[v])%MOD; 80 solve(v,u); 81 mp[dis[u]]=(mp[dis[u]]+s[v]-n+MOD)%MOD; 82 } 83 e=nxt[e]; 84 } 85 mp[dis[u]]=(mp[dis[u]]+s[u])%MOD; 86 } 87 88 int main() 89 { 90 //freopen("1.in","r",stdin); 91 //freopen("1.out","w",stdout); 92 n=read(); 93 rep(i,1,n) head[i]=0; 94 tot=0; 95 rep(i,2,n) 96 { 97 int x=read(); 98 ll y; 99 scanf("%lld",&a[i]); 100 add(x,i); 101 } 102 ans=0; 103 dis[1]=0; 104 dfs(1,0); 105 mp.clear(); 106 solve(1,0); 107 printf("%lld ",ans); 108 return 0; 109 }