题意:
给定一棵n个点和有向边构成的树,其中一些边是合法边,一些边是非法边,
经过非法边需要1的费用,并且经过之后费用翻倍。
给定一个长为m的序列,问从点1开始按顺序移动到序列中对应点的总费用。
1<=n<=10^5,
1<=m<=10^6
题解:
还是比较水的…
正解是各种方法求LCA,在点上打标记,最后DFS一遍就可以得到答案。
用tarjan求LCA可以做到总复杂度O(n*α)…
我傻傻地见树就剖,强行O(n log n log n)碾过去了…
每次把起点终点之间的路径的经过次数加一,最后统计非法边对应的点,
对答案的贡献是 2^(次数)-1 。
ZKW线段树的常数还是比较可以接受的…虽然Codeforces机子本来就快…
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 #include <cstdio> #include <cstring> #define fore(p) for(int pt=h[p];pt;pt=e[pt].nx) typedef long long lint; const int N = 100010, MO = 1000000007; inline int read() { int s = 0; char c; while((c=getchar())<'0'||c>'9'); do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9'); return s; } int n,q,aa,bb,tot,ttot,cur,tg,curd,tim,ans,ql,qr,S,p2[1000010],h[N],top[N],d[N],hs[N],f[N],iw[N]; bool il[N],il2[N],qv; struct eg{int dt,nx;bool le;}e[N*2]; struct segt { int tr[N*3]; int query(int p){ int s = 0; for(int i=S+p;i>=1;i>>=1) s += tr[i]; return s; } void db(int l,int r) { for(l=l+S-1,r=r+S+1;l^r^1;l>>=1,r>>=1) { if(~l&1) tr[l^1]++; if( r&1) tr[r^1]++; } } }tr1,tr2; inline void link(int b) { e[++tot].nx = h[aa]; e[tot].dt = bb; e[tot].le = 1; h[aa] = tot; e[++tot].nx = h[bb]; e[tot].dt = aa; e[tot].le = b; h[bb] = tot; } int dfs1(int p,int ff) { f[p] = ff, d[p] = ++curd; int sz = 1,nx,t,mx=0; fore(p) { if((nx=e[pt].dt)==ff) continue; if(e[pt^1].le) il[nx] = 1; if(e[pt].le) il2[nx] = 1; t = dfs1(nx,p); if(t>mx) mx = t, hs[p] = nx; } curd--; return sz; } void dfs2(int p,int tp) { top[p] = tp; iw[p] = ++tim; if(hs[p]) dfs2(hs[p],tp); fore(p) if(e[pt].dt!=f[p]&&e[pt].dt!=hs[p]) dfs2(e[pt].dt,e[pt].dt); } void calc(int aa,int bb) { if(aa==bb) return; while(top[aa]!=top[bb]) { if(d[top[aa]]>d[top[bb]]) tr1.db(iw[top[aa]],iw[aa]), aa = f[top[aa]]; else tr2.db(iw[top[bb]],iw[bb]), bb = f[top[bb]]; } if(d[aa]>d[bb]) tr1.db(iw[bb]+1,iw[aa]); else tr2.db(iw[aa]+1,iw[bb]); } int main() { int i,j; n = read(); for(S=1;S<=n+2;S<<=1); for(i=2,tot=1;i<=n;i++) aa = read(), bb = read(), link(!read()); dfs1(1,0); dfs2(1,1); q = read(); for(p2[0]=1,i=1;i<=q;i++) p2[i] = ((lint)p2[i-1]<<1ll)%MO; for(cur=1;q--;cur=tg) tg = read(), calc(cur,tg); for(i=1;i<=n;i++) { if(!il[i]) ans = ((lint)ans+p2[tr1.query(iw[i])]+MO-1)%MO; if(!il2[i]) ans = ((lint)ans+p2[tr2.query(iw[i])]+MO-1)%MO; } printf("%d ",ans); return 0; }
补上O(n*α)的做法:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 #include <cstdio> #include <cstring> #define fore(p) for(int pt=h[p];pt;pt=e[pt].nx) typedef long long lint; const int N = 100010, MO = 1000000007; inline int read() { int s = 0; char c; while((c=getchar())<'0'||c>'9'); do{s=s*10+c-'0';}while((c=getchar())>='0'&&c<='9'); return s; } int n,q,aa,bb,tot,qtot,cur,tg,curd,tim,ans,p2[1000010],h[N],qh[N],rp[N],f[N],w[N],w2[N],mx; bool il[N],il2[N],b[N],qv; struct eg{int dt,nx;bool le;}e[N*2]; struct qu{int dt,nx;}qs[N*20]; inline void link(int b) { e[++tot].nx = h[aa]; e[tot].dt = bb; e[tot].le = 1; h[aa] = tot; e[++tot].nx = h[bb]; e[tot].dt = aa; e[tot].le = b; h[bb] = tot; } inline void linkq() { if(aa==bb) return; qs[++qtot].nx = qh[aa]; qs[qtot].dt = bb; qh[aa] = qtot; qs[++qtot].nx = qh[bb]; qs[qtot].dt = aa; qh[bb] = qtot; w[aa]++; w2[bb]++; } int findf(int p){ return f[p]==p?p:(f[p]=findf(f[p])); } void dfs(int p) { b[p] = 1; for(int nx,pt=qh[p];pt;pt=qs[pt].nx) if(b[qs[pt].dt]) nx = findf(qs[pt].dt), w[nx]--, w2[nx]--; for(int nx,pt=h[p];pt;pt=e[pt].nx) if(!b[nx=e[pt].dt]) { if(e[pt^1].le) il[nx] = 1; if(e[pt].le) il2[nx] = 1; dfs(nx); f[nx] = p; w[p] += w[nx]; w2[p] += w2[nx]; } if(w[p]>mx) mx = w[p]; if(w2[p]>mx) mx = w2[p]; } int main() { int i,j; for(n=read(),i=2,tot=1,f[1]=1;i<=n;i++) aa = read(), bb = read(), link(!read()), f[i] = i; for(q=read(),i=1,qtot=1,aa=1;i<=q;i++,aa=bb) bb = read(), linkq(); dfs(1); for(p2[0]=1,i=1;i<=mx;i++) p2[i] = ((lint)p2[i-1]<<1ll)%MO; for(i=1;i<=n;i++) { if(!il[i]) ans = ((lint)ans+p2[w[i]]-1)%MO; if(!il2[i]) ans = ((lint)ans+p2[w2[i]]-1)%MO; } printf("%d ",ans); return 0; }