学习了树的点分治,树的边分治似乎因为复杂度过高而并不出众,于是没学
自己总结了一下 有些时候面对一些树上的结构 并且解决的是和路径有关的问题的时候 如果是多个询问 关注点在每次给出两个点,求一些关于这两个点之间路径的问题的时候,我们可以使用树链剖分,但是如果是给出一个单一的询问,但是很宏观 类似于求所有点对之间路径满足xx的数量,这时候我们可以树形dp做些什么 但是有时候会遇到一些树形dp难以解决的东西,类似于数组开不下,无法转移状态这种问题,就可以用树分治
树分治基于一个思想 先确定一个点 找到所有过这个点的路径并判断 再对被这个点分开的连通块做同样的操作
QAQ于是做了几道模板题
POJ1741 模板题 求一棵树中 满足点对之间路径加和小于k的数量
这里用到了一个动态规划 判断一个数组中 选两个数能加起来<k 做法是sort后维护两个指针
其余的地方都很模板,需要注意的是要对root的所有son进行solve之后再重置lr
#include<stdio.h> #include<math.h> #include<string.h> #include<vector> #include<queue> #include<map> #include<string> #include<iostream> #include<algorithm> #include<stack> using namespace std; #define L long long #define pb push_back #define lala printf("-------- "); #define ph push #define rep(i, a, b) for (int i=a;i<=b;++i) #define dow(i, b, a) for (int i=b;i>=a;--i) #define fmt(i,n) if(i==n)printf(" ");else printf(" ") ; #define fi first #define se second template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));} int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;} int n , m ; bool vis[10050] ; int dis[10050] ; ///--- struct node { int v,w,nex; }b[10050*2]; int tot ; int head[10050] ; void init() { flc(head,-1); tot = 0 ; } void add(int u,int v,int w) { tot ++ ; b[tot].v=v; b[tot].w=w; b[tot].nex=head[u]; head[u]=tot; } ///--- int son[10050] ; int getsize(int u,int fa) { son[u] = 1 ; for(int i = head[u] ; i != -1 ; i = b[i].nex) { int v = b[i].v ; if(v==fa || vis[v]) continue ; son[u]+=getsize(v,u); } return son[u]; } int minn ; void getroot(int u,int fa,int &root,int siz) { int maxx = siz - son[u] ; for(int i = head[u] ; i != -1 ; i = b[i].nex) { int v=b[i].v ; if(v==fa || vis[v]) continue ; getroot(v,u,root,siz) ; maxx = max(maxx,son[v]) ; } if(minn == -1 || maxx < minn) { minn = maxx ; root = u ; } } ///--- int l , r ; void getdepth(int u,int fa,int xd) { dis[++r] = xd ; for(int i = head[u] ; i != -1 ; i = b[i].nex) { int v = b[i].v ; int w = b[i].w ; if(v == fa || vis[v]) continue ; getdepth(v , u , xd + w) ; } } bool cmp(int a, int b) { return a<b; } int getdep(int l , int r) { if(l >= r) return 0 ; sort(dis + l , dis + r + 1 , cmp ) ; int res = 0 ; int le = l ; int ri = l-1 ; while(ri+1 <= r && dis[ri+1] + dis[le] <= m) { ri ++ ; res ++ ; } while(le + 1 <= r) { le ++ ; while(ri >= l && dis[ri] + dis[le] > m) ri -- ; res += ri - l + 1 ; } for(int i = l ; i <= r ; i ++ ) { if(dis[i]*2 <= m) res -- ; } return (res / 2) ; } ///--- int solve(int u) { int siz = getsize(u , -1) ; minn = -1 ; int root = -1 ; getroot(u , -1 , root , siz) ; vis[root] = true ; int res = 0 ; for(int i = head[root] ; i != -1 ; i = b[i].nex) { int v = b[i].v ; if(vis[v]) continue ; int z = solve(v) ; res += z ; } l = 1 ; r = 0 ; for(int i = head[root] ; i != -1 ; i = b[i].nex) { int v = b[i].v ; int w = b[i].w ; if(vis[v]) continue ; getdepth(v , root , w) ; res -= getdep(l , r) ; l = r + 1 ; } res += getdep(1 , r) ; for(int i = 1 ; i <= r ; i ++ ) { if(dis[i] <= m) res ++ ; else break ; } vis[root] = false ; return res ; } int main () { while(scanf("%d%d" , &n, &m) != EOF) { if(n==0&&m==0) break ; init() ; rep(i,1,n-1) { int u,v,w; u=read();v=read();w=read(); add(u,v,w); add(v,u,w); } memset(vis,false,sizeof(vis)); int ans = solve(1) ; printf("%d " , ans) ; } }
BZOJ 2152 求路径%3==0的点对的数量
因为只是%3 所以比较容易些。。如果写树形dp的话会比较好写 维护一个dp[n][3]的数组就可以
但是如果不是3是很大的数字 就得开dp[n][m] 如果开不下的话就得树分治
/// 树形dp跑得又好又快QAQ
#include<stdio.h> #include<math.h> #include<string.h> #include<vector> #include<queue> #include<map> #include<string> #include<iostream> #include<algorithm> #include<stack> using namespace std; #define L long long #define pb push_back #define lala printf("-------- "); #define ph push #define rep(i, a, b) for (int i=a;i<=b;++i) #define dow(i, b, a) for (int i=b;i>=a;--i) #define fmt(i,n) if(i==n)printf(" ");else printf(" ") ; #define fi first #define se second template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));} int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;} int n ; struct node { int v,w,nex; }b[20050 * 2]; int head[20050]; int tot ; void add(int u,int v,int w) { tot++; b[tot].v=v;b[tot].w=w; b[tot].nex=head[u];head[u]=tot; } void init() { flc(head,-1); tot=0; } int dp[20050][5] ; int ans ; void dfs(int u,int fa) { int a[4]; flc(a,0) ; a[0] = 1 ; for(int i=head[u];i!=-1;i=b[i].nex) { int v=b[i].v ; int w=b[i].w ; if(v==fa) continue ; dfs(v,u) ; rep(j,0,2) { dp[u][(j+w)%3]+=dp[v][j] ; } rep(j,0,2) { int z=w+j; z%=3 ; if(z==0) { ans += a[0]*dp[v][j] ; } if(z==1) { ans += a[2]*dp[v][j] ; } if(z==2) { ans += a[1]*dp[v][j] ; } } rep(j,0,2) { a[(j+w)%3]+=dp[v][j] ; } } dp[u][0] ++ ; } int main (){ while(scanf("%d" , &n) != EOF) { init() ; flc(dp,0); ans = 0 ; rep(i,1,n-1){ int u=read(),v=read(),w=read(); add(u,v,w);add(v,u,w); } dfs(1,-1); int fm = n*n; ans *= 2 ; ans += n ; int gc = __gcd(fm,ans) ; fm/=gc ; ans/=gc ; printf("%d/%d",ans,fm) ; } }
#include<stdio.h> #include<math.h> #include<string.h> #include<vector> #include<queue> #include<map> #include<string> #include<iostream> #include<algorithm> #include<stack> using namespace std; #define L long long #define pb push_back #define lala printf("-------- "); #define ph push #define rep(i, a, b) for (int i=a;i<=b;++i) #define dow(i, b, a) for (int i=b;i>=a;--i) #define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex) #define fmt(i,n) if(i==n)printf(" ");else printf(" ") ; #define fi first #define se second template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));} int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;} int n , m ; int ans ; bool vis[20050] ; int dis[20050] ; ///--- struct node { int v,w,nex; }b[20050*2]; int tot ; int head[20050] ; void init() { flc(head,-1); tot = 0 ; } void add(int u,int v,int w) { tot ++ ; b[tot].v=v; b[tot].w=w; b[tot].nex=head[u]; head[u]=tot; } ///--- int siz[20050]; int getsize(int u,int fa) { siz[u] = 1 ; for(int i=head[u];i!=-1;i=b[i].nex) { int v=b[i].v; if(v==fa||vis[v]) continue ; siz[u]+=getsize(v,u); } return siz[u]; } int minn ; void getroot(int u,int fa,int num,int &root) { int maxx=0; for(int i=head[u];i!=-1;i=b[i].nex){ int v=b[i].v ; if(v==fa||vis[v]) continue ; getroot(v,u,num,root); maxx=max(maxx,siz[v]); } maxx=max(maxx,num-siz[u]); if(maxx<minn){ minn=maxx;root=u; } } ///--- int l,r; void getdepth(int u,int fa,int xd) { dis[++r]=xd ; rnode(i,u) { int v=b[i].v ; if(v==fa||vis[v]) continue ; int w=b[i].w ; getdepth(v,u,xd+w) ; } } int getdep(int l,int r) { if(l>r) return 0 ; int a[3] ; flc(a,0) ; rep(i,l,r) { a[dis[i]%3] ++ ; } int res = 0 ; rep(i,0,2) { rep(j,0,2) { if((i+j)%3==0) res += a[i]*a[j] ; } } return res ; } ///--- int solve(int u) { int num = getsize(u,-1); minn = 999999999 ; int root ; getroot(u,-1,num,root); int ans = 0 ; vis[root]=true; rnode(i,root) { int v=b[i].v; if(vis[v]) continue ; ans += solve(v) ; } l = 1 ; r = 0 ; rnode(i,root) { int v=b[i].v; int w=b[i].w; if(vis[v]) continue ; getdepth(v,root,w) ; ans -= getdep(l,r) ; l = r + 1 ; } dis[++r] = 0 ; ans += getdep(1,r) ; vis[root] = false ; return ans ; } int main () { while(scanf("%d" , &n) != EOF) { init() ; rep(i,1,n-1) { int u=read();int v=read() ; int w = read(); add(u,v,w) ; add(v,u,w) ; } memset(vis,false,sizeof(vis)) ; int ans = solve(1) ; int fm = n*n ; int g = __gcd(ans,fm) ; fm/=g ; ans/=g ; printf("%d/%d " , ans , fm) ; } }
HDU 5977 大连的铜牌题 求包含所有颜色的路径的数目 k<=10
这个题的颜色来源于点 在对root的son进行getdepth的时候 需要把root的颜色给带下去 因为我们用ans-root的同一个son内的孩子 里面肯定是包含root的颜色的 这个关系是或 所以可以直接或上去
#include<stdio.h> #include<math.h> #include<string.h> #include<vector> #include<queue> #include<map> #include<string> #include<iostream> #include<algorithm> #include<stack> using namespace std; #define L long long #define pb push_back #define lala printf("-------- "); #define ph push #define rep(i, a, b) for (L i=a;i<=b;++i) #define dow(i, b, a) for (L i=b;i>=a;--i) #define rnode(i,u) for(L i = head[u] ; i != -1 ; i = b[i].nex) #define fmt(i,n) if(i==n)printf(" ");else printf(" ") ; #define fi first #define se second template<class T> inline void flc(T &A, L x){memset(A, x, sizeof(A));} L read(){L x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;} L n , k , m ; L ans ; bool vis[50050] ; L bl[50050] ; ///--- struct node { L v,nex; }b[50050*2]; L tot ; L head[50050] ; L dis[50050] ; void init() { flc(head,-1); tot = 0 ; memset(vis,false,sizeof(vis)) ; } void add(L u,L v) { tot ++ ; b[tot].v=v; b[tot].nex=head[u]; head[u]=tot; } ///--- vector<int>q[2050] ; void thefirst() { L z = (1<<k)-1 ; rep(i,0,1024) q[i].clear() ; rep(i,0,z) { rep(j,0,z) { if( (i|j) == z) { q[i].pb(j) ; } } } } ///--- L siz[50050]; L getsize(L u,L fa) { siz[u] = 1 ; rnode(i,u){ L v=b[i].v; if(v==fa||vis[v]) continue ; siz[u]+=getsize(v,u); } return siz[u]; } L minn ; void getroot(L u,L fa,L num,L &root) { L maxx=0; rnode(i,u){ L v=b[i].v ; if(v==fa||vis[v]) continue ; getroot(v,u,num,root); maxx=max(maxx,siz[v]); } maxx=max(maxx,num-siz[u]); if(minn==-1||maxx<minn){ minn=maxx;root=u; } } ///--- L l , r ; void getdepth(L u,L fa,L xd) { xd |= (1 << (bl[u]-1)) ; dis[++r] = xd ; rnode(i,u) { L v=b[i].v; if(vis[v] || v==fa) continue ; getdepth(v,u,xd) ; } } L mp[2050] ; L getdep(L l , L r) { if(l>r) return 0 ; flc(mp,0) ; L ans = 0 ; rep(i,l,r) { L x=dis[i] ; for(L i=0;i<q[x].size();i++){ L y=q[x][i]; ans += mp[y] ; } mp[x] ++ ; } return ans ; } ///--- L solve(L u) { L siz = getsize(u,-1) ; L root = -1; minn = -1 ; getroot(u,-1,siz,root) ; vis[root]=true ; L ans = 0 ; rnode(i,root) { L v=b[i].v; if(vis[v]) continue ; ans += solve(v) ; } l = 1 ; r = 0 ; rnode(i,root) { L v=b[i].v ; if(vis[v]) continue ; getdepth(v,root,(1<<(bl[root]-1))) ; ans -= getdep(l,r) ; l = r + 1 ; } L x = (1<<(bl[root]-1)) ; L K = (1<<k)-1 ; rep(i,1,r) { if((x | dis[i]) == K) { ans ++ ; } } ans += getdep(1,r) ; vis[root]=false; return ans ; } int main () { while(scanf("%lld%lld" , &n,&k) != EOF) { init() ; thefirst() ; rep(i,1,n) bl[i] = read() ; rep(i,1,n-1) { L u=read(),v=read(); add(u,v) ; add(v,u) ; } if(k == 1) { printf("%lld " , n*n) ; continue ; } L ans = solve(1) ; printf("%lld " , ans*2) ; } }
学会了树分治之后开启了新技能“看见什么不明显DP的树上结构就觉得可以树分治” 感觉要分治算法学傻。。
训练赛看到一个题 感觉树形DP不可做 于是想树分治 发现解决不了这个问题 但是感觉还是树分治 赛后发现果然
uvaLive 6900 给出一棵树 每条边有cost与val 我有C 在树上选一条路径出来 使sum(cost) <= C时的最大val
这个和加减不太一样 因为加减是可以通过对root的son来操作进行去重的 上一个大连的是进行或运算 也无可厚非 但是这个求max 是不可逆的
但是我们本来就不需要去重 和以前模板思路不一样的是 我们保存dis数组中 每一个值来自哪个root的儿子R 然后对R排序 处理完一个R再搞另一个R 我们不需要排序 因为根据dfs的特性 相同的R一定有且只有一段 所以不需要sort 和之前的去重没有什么时间上的差别 因为省去了去重的时间 所以我想 时间应该会更快
在第一道题里面 用一个sort+O(n)单调思想 其实sort就撑到nlogn了 所以之后的nlogn也是可以接受的 可以做一个树状数组 来维护前缀max
因为不能开太大 所以进行一个离散化 时间也是nlogn的 最后的复杂度还是nlognlogn 虽然常数大点
这种思想是泛用的 之前的几道题也可以这么做
uvaLive 6900
#include<stdio.h> #include<math.h> #include<string.h> #include<vector> #include<queue> #include<map> #include<string> #include<iostream> #include<algorithm> #include<stack> using namespace std; #define L long long #define pb push_back #define lala printf("-------- "); #define ph push #define rep(i, a, b) for (int i=a;i<=b;++i) #define dow(i, b, a) for (int i=b;i>=a;--i) #define fmt(i,n) if(i==n)printf(" ");else printf(" ") ; #define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex) #define fi first #define se second template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));} int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;} int n , m; struct node { int vol,val; int R ; }dis[20050]; bool vis[20050]; struct no { int v,vol,val,nex; }b[20050*2]; int head[20050]; int tot; void init() { flc(head,-1); tot=0; } void add(int u,int v,int vol,int val) { tot++ ; b[tot].v=v;b[tot].vol=vol;b[tot].val=val; b[tot].nex=head[u] ; head[u]=tot; } int V ; ///--- int son[20050] ; int getsize(int u,int fa) { son[u] = 1 ; rnode(i,u) { int v=b[i].v; if(v==fa || vis[v]) continue ; son[u] += getsize(v,u) ; } return son[u] ; } int minn ; void getroot(int u,int fa,int &root,int siz) { int maxx = siz - son[u] ; rnode(i,u) { int v=b[i].v ; if(v==fa || vis[v]) continue ; if(son[v] > maxx) maxx = son[v] ; getroot(v,u,root,siz) ; } if(maxx < minn) { minn = maxx; root = u ; } } ///--- int l,r ; void getdepth(int u,int fa,int xdvol,int xdval,int sp) { node tmp ; tmp.vol = xdvol ; tmp.val = xdval ; tmp.R = sp ; dis[++r] = tmp ; rnode(i,u) { int v = b[i].v ; if(v == fa || vis[v]) continue ; getdepth(v,u,xdvol + b[i].vol,xdval + b[i].val,sp) ; } } int c[40050] ; int lowbit(int x) { return (x&(-x)) ; } void segadd(int x,int ma) { while(x<=40000) { c[x]=max(c[x] , ma) ; x+=lowbit(x) ; } } int fin(int x) { int res = 0 ; while(x>0) { res = max(c[x],res); x-=lowbit(x) ; } return res ; } int calc(int l,int r) { if(l > r) return 0 ; flc(c,0) ; vector<int>ls ; ls.clear() ; rep(i,l,r) { ls.pb(dis[i].vol) ; } int res = 0 ; sort(ls.begin(),ls.end()) ; ls.erase(unique(ls.begin(),ls.end()) , ls.end()) ; for(int i = l ; i <= r ; i ++ ) { int j = i ; while(j <= r && dis[j].R == dis[i].R) { int z = dis[j].vol ; int val1 = dis[j].val ; if(z > V) { j ++ ; continue ; } int x = V - z ; int id = -2 ; int ll = 0 ; int rr = ls.size()-1 ; while(ll<=rr) { int mid=(ll+rr)/2 ; if(ls[mid]<=x) { id=mid; ll=mid+1; } else { rr=mid-1; } } if(id==-2){ j++; continue ; } int rres = fin(id+1) ; res = max(res , rres + dis[j].val) ; j ++ ; } j -- ; rep(k,i,j) { int vol = dis[k].vol ; int val = dis[k].val ; int id = lower_bound(ls.begin(),ls.end(),vol)-ls.begin()+1 ; segadd(id,val) ; } i = j ; } return res ; } ///--- int solve(int u) { int siz = getsize(u,-1) ; minn = 999999999 ; int root ; getroot(u,-1,root,siz) ; vis[root] = true ; int res = 0 ; rnode(i,root) { int v=b[i].v ; if(vis[v]) continue ; int x = solve(v) ; res = max(res,x) ; } l = 1 ; r = 0 ; rnode(i,root) { int v = b[i].v ; int vol = b[i].vol ; int val = b[i].val ; if(vis[v]) continue ; getdepth(v,root,vol,val,v) ; } res = max(res , calc(1,r)) ; rep(i,1,r) { if(dis[i].vol <= V) { res = max(dis[i].val , res) ; } } vis[root] = false ; return res ; } int main () { int t = read(); while(t -- ) { n = read(); init() ; rep(i,2,n) { int u=read(),v=read(),vol=read(),val=read() ; add(u,v,vol,val); add(v,u,vol,val); } V = read() ; memset(vis,false,sizeof(vis)); int ans=solve(1) ; printf("%d " , ans) ; } }
BZOJ 2152 用这种方法改了一下 发现由于必须sort 所以复杂度比之前的做法要多一个log
#include<stdio.h> #include<math.h> #include<string.h> #include<vector> #include<queue> #include<map> #include<string> #include<iostream> #include<algorithm> #include<stack> using namespace std; #define L long long #define pb push_back #define lala printf("-------- "); #define ph push #define rep(i, a, b) for (int i=a;i<=b;++i) #define dow(i, b, a) for (int i=b;i>=a;--i) #define rnode(i,u) for(int i = head[u] ; i != -1 ; i = b[i].nex) #define fmt(i,n) if(i==n)printf(" ");else printf(" ") ; #define fi first #define se second template<class T> inline void flc(T &A, int x){memset(A, x, sizeof(A));} int read(){int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;} int n , m ; int ans ; bool vis[20050] ; struct no { int x , R; }dis[20050] ; ///--- struct node { int v,w,nex; }b[20050*2]; int tot ; int head[20050] ; void init() { flc(head,-1); tot = 0 ; } void add(int u,int v,int w) { tot ++ ; b[tot].v=v; b[tot].w=w; b[tot].nex=head[u]; head[u]=tot; } ///--- int siz[20050]; int getsize(int u,int fa) { siz[u] = 1 ; for(int i=head[u];i!=-1;i=b[i].nex) { int v=b[i].v; if(v==fa||vis[v]) continue ; siz[u]+=getsize(v,u); } return siz[u]; } int minn ; void getroot(int u,int fa,int num,int &root) { int maxx=0; for(int i=head[u];i!=-1;i=b[i].nex){ int v=b[i].v ; if(v==fa||vis[v]) continue ; getroot(v,u,num,root); maxx=max(maxx,siz[v]); } maxx=max(maxx,num-siz[u]); if(maxx<minn){ minn=maxx;root=u; } } ///--- int l,r; void getdepth(int u,int fa,int xd,int sp) { no tmp ; tmp.x = xd ; tmp.R = sp ; dis[++r] = tmp ; rnode(i,u) { int v=b[i].v ; if(v==fa||vis[v]) continue ; int w=b[i].w ; getdepth(v,u,xd+w,sp) ; } } int getdep(int l,int r) { if(l>r) return 0; int res = 0 ; int a[5] ; flc(a,0) ; for(int i = l ; i <= r ; i ++ ) { int j = i ; while(j <= r && dis[j].R==dis[i].R) { int x = dis[j].x % 3 ; int ned = 3 - x ; ned %= 3 ; res += a[ned] ; j ++ ; } j -- ; rep(k,i,j) { int x = dis[k].x % 3 ; a[x] ++ ; } i = j ; } return res ; } ///--- int solve(int u) { int num = getsize(u,-1); minn = 999999999 ; int root ; getroot(u,-1,num,root); int ans = 0 ; vis[root]=true; rnode(i,root) { int v=b[i].v; if(vis[v]) continue ; ans += solve(v) ; } l = 1 ; r = 0 ; rnode(i,root) { int v=b[i].v; int w=b[i].w; if(vis[v]) continue ; getdepth(v,root,w,v) ; } ans += getdep(1,r) ; rep(i,1,r) { if(dis[i].x % 3 == 0) ans ++ ; } vis[root] = false ; return ans ; } int main () { while(scanf("%d" , &n) != EOF) { init() ; rep(i,1,n-1) { int u=read();int v=read() ; int w = read(); add(u,v,w) ; add(v,u,w) ; } memset(vis,false,sizeof(vis)) ; int ans = solve(1) ; ans *= 2; ans += n ; int fm = n*n ; int g = __gcd(ans,fm) ; fm/=g ; ans/=g ; printf("%d/%d " , ans , fm) ; } }