题意:给定一棵树,初始边权为零,对其进行以下操作:
1. 把树上a到b的路径上的边权:0变为1,1变为0。
2. 与树上a到b的路径相邻的边(只有一个公共点),把它们的边权:0变为1,1变为0。
3. 查询a到b的路径上的边权和。
思路:
操作1用普通的重链剖分就可以解决。
对于操作2,可以看出,在每次从重链的底部跳到重链顶端的过程中,除重链底部和重链顶端可能连一条相邻的(不是底部到顶端路径上的)重边外,其余的相邻边都是轻边。所以我们可以再开一个线段树记录与树上的每个点是否被标记,如果被标记,说明与它相邻的轻边也被标记。
对于操作3,路径上的所有重边可以通过第一个线段树(存边权和那个)解决,路径上的轻边则还取决于:它的两个端点是否被标记。
代码:
1 #include<bits/stdc++.h> 2 #define lson l,mid,rt<<1 3 #define rson mid+1,r,rt<<1|1 4 using namespace std; 5 const int N=1e5+10; 6 const int M=2e5+10; 7 const int inf=0x3f3f3f3f; 8 9 int head[N],to[M],nxt[M]; 10 int tot; 11 int dep[N],fa[N],siz[N],son[N]; 12 int top[N],id[N]; 13 int cnt; 14 int n; 15 int sum[N<<2],mark[N<<2],tag1[N<<2],tag2[N<<2]; 16 17 inline void add(int u,int v){ 18 to[++tot]=v,nxt[tot]=head[u],head[u]=tot; 19 } 20 21 void dfs1(int u,int f){ 22 dep[u]=dep[f]+1; 23 fa[u]=f; 24 siz[u]=1; 25 son[u]=0; 26 int maxs=-1; 27 for(int i=head[u];i;i=nxt[i]){ 28 int v=to[i]; 29 if(v==f) continue; 30 dfs1(v,u); 31 siz[u]+=siz[v]; 32 if(siz[v]>maxs) maxs=siz[v],son[u]=v; 33 } 34 } 35 36 void dfs2(int u,int topf){ 37 id[u]=++cnt; 38 top[u]=topf; 39 if(!son[u]) return; 40 dfs2(son[u],topf); 41 for(int i=head[u];i;i=nxt[i]){ 42 int v=to[i]; 43 if(v==fa[u]||v==son[u]) continue; 44 dfs2(v,v); 45 } 46 } 47 48 void build(int l,int r,int rt){ 49 sum[rt]=mark[rt]=tag1[rt]=tag2[rt]=0; 50 if(l==r) return; 51 int mid=(l+r)>>1; 52 build(lson); 53 build(rson); 54 } 55 56 inline void push_up(int rt){ 57 sum[rt]=sum[rt<<1]+sum[rt<<1|1]; 58 } 59 60 void push_down1(int rt,int m){ 61 if(tag1[rt]){ 62 sum[rt<<1]=m-(m>>1)-sum[rt<<1]; 63 tag1[rt<<1]^=1; 64 sum[rt<<1|1]=(m>>1)-sum[rt<<1|1]; 65 tag1[rt<<1|1]^=1; 66 tag1[rt]=0; 67 } 68 } 69 70 void update1(int x,int y,int l,int r,int rt){ 71 if(x<=l&&y>=r){ 72 sum[rt]=(r-l+1)-sum[rt]; 73 tag1[rt]^=1; 74 return; 75 } 76 push_down1(rt,r-l+1); 77 int mid=(l+r)>>1; 78 if(x<=mid) update1(x,y,lson); 79 if(y>mid) update1(x,y,rson); 80 push_up(rt); 81 } 82 83 int query1(int x,int y,int l,int r,int rt){ 84 if(x<=l&&y>=r) return sum[rt]; 85 push_down1(rt,r-l+1); 86 int mid=(l+r)>>1; 87 int ans=0; 88 if(x<=mid) ans+=query1(x,y,lson); 89 if(y>mid) ans+=query1(x,y,rson); 90 return ans; 91 } 92 93 void push_down2(int rt){ 94 if(tag2[rt]){ 95 mark[rt<<1]^=1; 96 mark[rt<<1|1]^=1; 97 tag2[rt<<1]^=1; 98 tag2[rt<<1|1]^=1; 99 tag2[rt]=0; 100 } 101 } 102 103 void update2(int x,int y,int l,int r,int rt){ 104 if(x<=l&&y>=r){ 105 mark[rt]^=1; 106 tag2[rt]^=1; 107 return; 108 } 109 push_down2(rt); 110 int mid=(l+r)>>1; 111 if(x<=mid) update2(x,y,lson); 112 if(y>mid) update2(x,y,rson); 113 } 114 115 int query2(int x,int l,int r,int rt){ 116 if(l==r) return mark[rt]; 117 push_down2(rt); 118 int mid=(l+r)>>1; 119 if(x<=mid) return query2(x,lson); 120 return query2(x,rson); 121 } 122 123 void uprange1(int x,int y){ 124 while(top[x]!=top[y]){ 125 if(dep[top[x]]<dep[top[y]]) swap(x,y); 126 update1(id[top[x]],id[x],1,n,1); 127 x=fa[top[x]]; 128 } 129 if(x==y) return; 130 if(dep[x]>dep[y]) swap(x,y); 131 update1(id[x]+1,id[y],1,n,1); 132 } 133 134 void uprange2(int x,int y){ 135 while(top[x]!=top[y]){ 136 if(dep[top[x]]<dep[top[y]]) swap(x,y); 137 update2(id[top[x]],id[x],1,n,1); 138 if(son[x]) update1(id[son[x]],id[son[x]],1,n,1); 139 x=fa[top[x]]; 140 } 141 if(dep[x]>dep[y]) swap(x,y); 142 if(son[y]) update1(id[son[y]],id[son[y]],1,n,1); 143 if(son[fa[x]]==x) update1(id[x],id[x],1,n,1); 144 update2(id[x],id[y],1,n,1); 145 } 146 147 int qrange(int x,int y){ 148 int ans=0; 149 while(top[x]!=top[y]){ 150 if(dep[top[x]]<dep[top[y]]) swap(x,y); 151 if(x!=top[x]) 152 ans+=query1(id[top[x]]+1,id[x],1,n,1); 153 ans+=query1(id[top[x]],id[top[x]],1,n,1)^query2(id[fa[top[x]]],1,n,1)^query2(id[top[x]],1,n,1); 154 x=fa[top[x]]; 155 } 156 if(x==y) return ans; 157 if(dep[x]>dep[y]) swap(x,y); 158 ans+=query1(id[x]+1,id[y],1,n,1); 159 return ans; 160 } 161 162 int main() 163 { 164 int T; 165 scanf("%d",&T); 166 while(T--){ 167 tot=cnt=0; 168 memset(head,0,sizeof(head)); 169 scanf("%d",&n); 170 for(int i=1;i<n;i++){ 171 int u,v; 172 scanf("%d%d",&u,&v); 173 add(u,v); 174 add(v,u); 175 } 176 dfs1(1,0); dfs2(1,1); 177 build(1,n,1); 178 int q; 179 scanf("%d",&q); 180 while(q--){ 181 int t,x,y; 182 scanf("%d%d%d",&t,&x,&y); 183 if(t==1) uprange1(x,y); 184 else if(t==2) uprange2(x,y); 185 else printf("%d ",qrange(x,y)); 186 } 187 } 188 return 0; 189 }