zoukankan      html  css  js  c++  java
  • [JZOJ4331] 【清华集训模拟】树

    题目

    题目大意

    给你一棵带点权的树,求将树变成一堆不相交的链,而且这些链的权值和非负的方案数。


    正解

    显然这道题是个(DP)
    首先求个前缀和(sum)
    为了后面讲述方便,我这样设:(f_{i,j})表示以(i)为根的子树,其中某条链从(x)伸出到(i)的方案数,而且(sum_x=j)
    还有设(g_i)表示以(i)为根的,没有伸出去的链的方案数。
    显然有这样的转移:

    [prod g_i o f_{x,sum_x}\ f_{y,j}prod_{i eq y} g_i o f_x,j]

    [f_{x,j} o g_x (j-sum_{fa_x}geq 0)\ f_{y,j}f_{z,k}prod_{i eq y,i eq z}g_i o g_x (j+k-2sum_x+a_xgeq 0) ]

    如果直接这样搞肯定会爆炸。所以考虑用线段树来维护(f)
    由于可能会出现(g)值为(0)的情况,所以不能直接用逆元来搞。
    要维护个前缀积和后缀积。
    首先要求出重儿子,把重儿子作为第一个儿子,然后线段树合并之前也启发式合并。
    具体来说,我们钦定(j<k)。在合并的时候(设前面子树合并出来的线段树为(A),这个线段树为(B))当前的儿子作为(k),遍历(B)的所有叶子节点,并在(A)中区间询问。这时候记得要乘上后缀积。将询问出来的东西加在(g_x)中。
    然后两个合并在一起。记得在合并之前,整个(A)乘子树的(g_x),整个(B)乘前缀积
    。搞完这个再合并。
    最后你就会愉快地发现,所有子树合并之后就是上面第二行式子。这样只需要把第一行的加进去。第四行的式子已经计算完了,只需要再加上第一个式子就可以了。

    然而这不是题解的做法,作为一个小蒟蒻,表示看不懂题解……


    代码

    using namespace std;
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <cassert>
    #define N 100010
    #define INF 1000000000
    #define mo 1000000007
    int n;
    int a[N];
    struct EDGE{
    	int to;
    	EDGE *las;
    } e[N*2];
    int ne;
    EDGE *last[N];
    struct Node *null;
    struct Node{
    	Node *l,*r;
    	int sum,tag;
    	inline void get_tag(int c){sum=(long long)sum*c%mo;tag=(long long)tag*c%mo;}
    	inline void pushdown(){
    		if (tag!=1){
    			l->get_tag(tag);
    			r->get_tag(tag);
    			tag=1;
    		}
    	}
    	inline void update(){sum=l->sum+r->sum;sum>=mo?sum-=mo:0;}
    } d[N*40];
    int cnt;
    Node *rt[N];
    inline Node *newnode(){return &(d[++cnt]={null,null,0,1});}
    void add(Node *&t,int l,int r,int x,int c){
    	if (t==null)
    		t=newnode();
    	if (l==r){
    		(t->sum+=c)%=mo;
    		return;
    	}
    	t->pushdown();
    	int mid=l+r>>1;
    	if (x<=mid)
    		add(t->l,l,mid,x,c);
    	else
    		add(t->r,mid+1,r,x,c);
    	t->update();
    }
    int query(Node *t,int l,int r,int st,int en){
    	if (t==null)
    		return 0;
    	if (st<=l && r<=en)
    		return t->sum;
    	t->pushdown();
    	int mid=l+r>>1,res=0;
    	if (st<=mid)
    		res+=query(t->l,l,mid,st,en);
    	if (mid<en)
    		res+=query(t->r,mid+1,r,st,en);
    	return res>=mo?res-mo:res;
    }
    Node *merge(Node *a,Node *b){
    	if (a==null)
    		return b;
    	if (b==null)
    		return a;
    	a->pushdown(),b->pushdown();
    	a->l=merge(a->l,b->l);
    	a->r=merge(a->r,b->r);
    	a->sum+=b->sum;
    	a->sum>=mo?a->sum-=mo:0;
    	return a;
    }
    int calc(Node *t,int l,int r,Node *rt,int bor){
    	if (t==null)
    		return 0;
    	if (l==r)
    		return (long long)query(rt,-INF,INF,bor-l,INF)*t->sum%mo;
    	t->pushdown();
    	int mid=l+r>>1,res=0;
    	res+=calc(t->l,l,mid,rt,bor);
    	res+=calc(t->r,mid+1,r,rt,bor);
    	return res>=mo?res-mo:res;
    }
    int fa[N],sum[N],siz[N],hs[N];
    int son[N],ns,pre[N],suc[N];
    int g[N];
    void dp(int x){
    	sum[x]=sum[fa[x]]+a[x];
    	siz[x]=1;
    	for (EDGE *ei=last[x];ei;ei=ei->las)
    		if (ei->to!=fa[x]){
    			fa[ei->to]=x;
    			dp(ei->to);
    			siz[x]+=siz[ei->to];
    			if (siz[ei->to]>siz[hs[x]])
    				hs[x]=ei->to;
    		}
    	if (!hs[x]){
    		rt[x]=newnode();
    		add(rt[x],-INF,INF,sum[x],1);
    		g[x]=(a[x]>=0);
    		return;
    	}
    	son[ns=1]=hs[x];
    	pre[0]=1,pre[1]=g[hs[x]];
    	for (EDGE *ei=last[x];ei;ei=ei->las)
    		if (ei->to!=fa[x] && ei->to!=hs[x]){
    			son[++ns]=ei->to;
    			pre[ns]=(long long)pre[ns-1]*g[ei->to]%mo;
    		}
    	suc[ns+1]=1;
    	for (int i=ns;i>=1;--i)
    		suc[i]=(long long)suc[i+1]*g[son[i]]%mo;
    	rt[x]=rt[hs[x]];
    	for (int i=2;i<=ns;++i){
    		g[x]=(g[x]+(long long)calc(rt[son[i]],-INF,INF,rt[x],sum[fa[x]]*2+a[x])*suc[i+1])%mo;
    		rt[son[i]]->get_tag(pre[i-1]);
    		rt[x]->get_tag(g[son[i]]);
    		rt[x]=merge(rt[x],rt[son[i]]);
    	}
    	add(rt[x],-INF,INF,sum[x],pre[ns]);
    	g[x]+=query(rt[x],-INF,INF,sum[fa[x]],INF);
    	g[x]>=mo?g[x]-=mo:0;
    }
    int main(){
    	freopen("tree.in","r",stdin);
    	freopen("tree.out","w",stdout);
    	scanf("%d",&n);
    	for (int i=1;i<=n;++i)
    		scanf("%d",&a[i]);
    	for (int i=1;i<n;++i){
    		int u,v;
    		scanf("%d%d",&u,&v);
    		e[ne]={v,last[u]};
    		last[u]=e+ne++;
    		e[ne]={u,last[v]};
    		last[v]=e+ne++;
    	}
    	null=d;
    	*null={null,null,0,0};
    	dp(n>>1);
    	printf("%d
    ",g[n>>1]);
    	return 0;
    }	
    

    总结

    好多树形DP都可以用线段树合并来优化啊……

  • 相关阅读:
    python logging模块
    python re模块
    python xml模块
    python json,pickle,shelve模块
    python os,sys模块
    python 临时添加环境变量
    python random模块
    python time模块
    python 装饰器的简单使用
    python学习之路(二)
  • 原文地址:https://www.cnblogs.com/jz-597/p/11421276.html
Copyright © 2011-2022 走看看