树上启发式合并,一种美妙的黑科技,可以用普通的优化让你$n^2$变成严格$n log$,解决一些类似(树上数颜色,树上查众数)这样的问题
首先你要知道暴力为什么是$n^2$的
以这个图为例
每次你从一个节点开始向下搜,你从1节点搜到3,搜完这个子树然后你需要把3存的col等信息删去再遍历另一个子树才是正确的
那么我们每次遍历这个节点一个子树,每次搜完这棵子树都要清空当前子树储存信息这样(最差)复杂度$n^2$
我们可以发现清空最后一个遍历的子树是没有意义的,那么我们人为把最后一个子树放到最后不就是最优的吗
所以,首先我们先找出来重链,轻链,对于轻链我们求出子树答案,再清除子树贡献,.然后求出重链上子树答案,不清除贡献.最后我们再算一遍子树对当前节点贡献即可
你可能会认为,这不就是一个简单的优化吗,怎么就是$n log$了
我不知道
它并没有优化最优复杂度而是避免了最差复杂度
以给一棵根为1的树,每次询问子树颜色种类数为例
代码大致如下
#include<bits/stdc++.h> using namespace std; #define ll int #define r register #define A 1001010 ll head[A],nxt[A],ver[A],size[A],col[A],cnt[A],ans[A],son[A]; ll tot=0,num,sum,nowson,n,m,xx,yy; inline void add(ll x,ll y){ nxt[++tot]=head[x],head[x]=tot,ver[tot]=y; } inline ll read(){ ll f=1,x=0;char c=getchar(); while(!isdigit(c)){ if(c=='-') f=-1; c=getchar(); } while(isdigit(c)) x=(x<<1)+(x<<3)+(c^48),c=getchar(); return f*x; } void dfs(ll x,ll fa){ size[x]=1; for(ll i=head[x];i;i=nxt[i]){ ll y=ver[i]; if(y==fa) continue; dfs(y,x); size[x]+=size[y]; if(size[son[x]]<size[y]) son[x]=y; } } void cal(ll x,ll fa,ll val){ if(!cnt[col[x]]) ++sum; cnt[col[x]]+=val; for(ll i=head[x];i;i=nxt[i]){ ll y=ver[i]; if(y==fa||y==nowson) continue; cal(y,x,val); } } void dsu(ll x,ll fa,bool op){ for(ll i=head[x];i;i=nxt[i]){ ll y=ver[i]; if(y==fa||y==son[x]) continue; dsu(y,x,0); //从轻儿子出发 } if(son[x]) dsu(son[x],x,1),nowson=son[x]; cal(x,fa,1);nowson=0; ans[x]=sum; if(!op){ cal(x,fa,-1); sum=0; } } int main(){ n=read(); for(ll i=1;i<=n-1;i++){ xx=read(),yy=read(); add(xx,yy),add(yy,xx); } for(ll i=1;i<=n;i++) col[i]=read(); dfs(1,0); dsu(1,0,1); m=read(); for(ll i=1;i<=m;i++){ xx=read(); printf("%d ",ans[xx]); } }
另一种打法
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> using namespace std; #define R register #define ll long long inline ll read(){ ll aa=0;R int bb=1;char cc=getchar(); while(cc<'0'||cc>'9') {if(cc=='-')bb=-1;cc=getchar();} while(cc>='0'&&cc<='9') {aa=(aa<<1)+(aa<<3)+(cc^48);cc=getchar();} return aa*bb; } const int N=1e5+3; struct edge{ int v,last; }ed[N<<1]; int first[N],tot; inline void add(int x,int y) { ed[++tot].v=y; ed[tot].last=first[x]; first[x]=tot; } int n,m,c[N],son[N],cnt[N],ans[N],siz[N]; void dfsi(int x,int fa) { siz[x]=1; for(R int i=first[x],v;i;i=ed[i].last){ v=ed[i].v; if(v==fa)continue; dfsi(v,x); siz[x]+=siz[v]; if(siz[v]>siz[son[x]])son[x]=v; } return; } int dfsj(int x,int fa,int bs,int kep) { if(kep){ for(R int i=first[x],v;i;i=ed[i].last){ v=ed[i].v; if(v!=fa&&v!=son[x]) dfsj(v,x,0,1); } } int res=0; if(son[x])res+=dfsj(son[x],x,1,kep); for(R int i=first[x],v;i;i=ed[i].last){ v=ed[i].v; if(v!=fa&&v!=son[x]) res+=dfsj(v,x,0,0); } if(!cnt[c[x]])res++; cnt[c[x]]++; if(kep){ ans[x]=res; if(!bs)memset(cnt,0,sizeof(cnt)); } return res; } int main() { n=read(); for(R int i=1,x,y;i<n;++i){ x=read();y=read(); add(x,y);add(y,x); } for(R int i=1;i<=n;++i)c[i]=read(); dfsi(1,0); dfsj(1,0,1,1); m=read(); for(R int i=1,x;i<=m;++i){ x=read(); printf("%d ",ans[x]); } return 0; }
虽然好像没什么区别
然后再看一道例题
有一棵 n 个节点的以 1 号节点为根的树,每个节点上有一个小桶,节点u上的小桶可以容纳${k_u}$ 个小球,ljh每次可以给一个节点到根路径上的所有节点的小桶内放一个小球,如果这个节点的小桶满了则不能放进这个节点,最后多次询问某个节点值
首先暴力不能过
直接权值线段树+线段树合并很难维护,树链剖分也难以维护,但我们直接树上启发式合并+线段树暴力修改可以维护。
首先单纯线段树暴力修改可以维护,但这会超时。于是我们用启发式合并作为时间复杂度保证,莫名奇妙AC了这个题
#include<bits/stdc++.h> using namespace std; #define ll long long #define A 1001010 ll head[A],nxt[A],ver[A],size[A],son[A],tong[A],col[A],getfa[A],isbigson[A],ans[A],al[A]; vector<pair<ll,ll> >v[A]; map<ll,ll>mp; ll n,m,tot=0,Q,wwb=0; struct tree{ ll l,r,f,x,t,c; }tr[A]; void add(ll x,ll y){ nxt[++tot]=head[x],head[x]=tot,ver[tot]=y; } void prdfs(ll x,ll fa){ size[x]=v[x].size()+1; for(ll i=head[x];i;i=nxt[i]){ ll y=ver[i]; if(y==fa) continue; prdfs(y,x); size[x]+=size[y]; if(size[son[x]]<size[y]) isbigson[son[x]]=0,son[x]=y,isbigson[y]=1; } } void built(ll p,ll l,ll r){ tr[p].l=l,tr[p].r=r; if(tr[p].l==tr[p].r){ return ; } ll mid=(l+r)>>1; built(p<<1,l,mid); built(p<<1|1,mid+1,r); } ll ask(ll p,ll pos){ if(pos>=tr[p].t) return tr[p].c; return (pos>=tr[p<<1].t?tr[p<<1].c+ask(p<<1|1,pos-tr[p<<1].t):ask(p<<1,pos)); } void insert(ll p,ll pos,ll t,ll c){ if(tr[p].l==tr[p].r) {tr[p].t+=t;tr[p].c+=c;return;} if(pos<=tr[p<<1].r) insert(p<<1,pos,t,c); else insert(p<<1|1,pos,t,c); tr[p].t=tr[p<<1].t+tr[p<<1|1].t; tr[p].c=tr[p<<1].c+tr[p<<1|1].c; } void up(ll x,ll fa){ if(v[getfa[x]].size()<v[getfa[fa]].size()){ for(ll i=0;i<v[getfa[x]].size();i++) v[getfa[fa]].push_back(v[getfa[x]][i]); v[getfa[x]].clear(); getfa[x]=getfa[fa]; } else{ for(ll i=0;i<v[getfa[fa]].size();i++) v[getfa[x]].push_back(v[getfa[fa]][i]); v[getfa[fa]].clear(); getfa[fa]=getfa[x]; } } void dfs(ll x,ll fa){ for(ll i=head[x];i;i=nxt[i]){ ll y=ver[i]; if(y==fa||y==son[x]) continue; dfs(y,x); } if(son[x]) dfs(son[x],x); for(ll i=0;i<v[getfa[x]].size();i++){ ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second; if(!al[col]) al[col]=tim,insert(1,tim,1,1); else if(al[col]>tim){ insert(1,al[col],0,-1); insert(1,tim,1,1); al[col]=tim; } else insert(1,tim,1,0); } // printf("t=%lld tong=%lld ",tr[1].t,tong[x]); ans[x]=ask(1,min(tr[1].t,tong[x])); if(son[x]) up(son[x],x); if(!isbigson[x]){ for(ll i=0;i<v[getfa[x]].size();i++){ ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second; if(al[col]==tim) insert(1,tim,-1,-1),al[col]=0; else insert(1,tim,-1,0); } up(x,fa); } /* for(ll i=1;i<=5;i++){ printf("ans=%lld ",ans[i]); } *//* cout<<endl;*/ } int main(){ scanf("%lld",&n); for(ll i=1;i<n;i++){ ll xx,yy; scanf("%lld%lld",&xx,&yy); add(xx,yy),add(yy,xx); } for(ll i=1;i<=n;i++){ scanf("%lld",&tong[i]); getfa[i]=i; } prdfs(1,0); scanf("%lld",&m);built(1,1,m); for(ll i=1,x,c;i<=m;i++){ scanf("%lld%lld",&x,&c); if(!mp[c]) mp[c]=++wwb; //离散化 v[x].push_back(make_pair(i,mp[c])); } dfs(1,0); scanf("%lld",&Q); for(ll i=1,x;i<=Q;i++){ scanf("%lld",&x); printf("%lld ",ans[x]); } }