zoukankan      html  css  js  c++  java
  • 「CodePlus 2018 4 月赛」Tommy 的结合(斜率优化)

    题意:

    思路:

    n<=66 (dp[i][j])表示匹配到i,j的时候最大值,转移枚举下一个点,总复杂度(O(n^4))

    为了降低复杂度

    (dp[l][r]=max(dp[i][j]+c[l][r]-dis[i][l]^2-dis[j][r]^2))

    如果想只枚举(r) ,可以把(dp[i][j]-dis[i][l]^2)给处理出来,就多开一个数组

    (dp[l][r])表示恰好到 (G[l'][r])表示往后多走了一个(l)

    (dp[i][j]=max(G[i][k]-(dis[j-1]-dis[k])^2)+C[i][j])

    (G[i][j]=max(dp[l][j]-(dis[i-1]-dis[l])^2))

    这样复杂度就是(O(n^3))

    看这样的式子可以想到斜率优化(链上的时候)

    (dp[i][j]=max(G[i][k]-dis[k]^2+2*dis[j-1]*dis[k])+c[i][j]-dis[j-1]^2)

    dis是单调递增的

    (G[l]-dis_l^2+2*dis_{j-1}*dis_l>=G[r]-dis_r^2+2*dis_{j-1}*dis_r)

    (dis_{j-1}>=frac{G_r-G_l+dis_l^2-dis_r^2}{dis_l-dis_r})

    维护斜率递增的凸包,G同理

    上面可以解决链上的情况

    然后就是树上的斜率优化,每次从父亲节点继承(l,r),可以二分这时候加入的右端点,记录原来栈里的值用于还原。

    这样子总的复杂度就是(O(n^2logn))

    #include<bits/stdc++.h>
    #define M 2705
    #define ll long long
    using namespace std;
    void Rd(int &res) {
    	res=0;
    	char c;
    	int fl=1;
    	while(c=getchar(),c<48)if(c=='-')fl=-1;
    	do res=(res<<1)+(res<<3)+(c^48);
    	while(c=getchar(),c>=48);
    	res*=fl;
    }
    struct Node {
    	int tot,n,to[M],pr[M],la[M],fa[M],dep[M],a[M],L[M],R[M],Id[M],id,dis[M];
    	void add(int x,int y) {
    		to[++tot]=y,pr[tot]=la[x],la[x]=tot;
    	}
    	void dfs(int x,int f) {
    		dep[x]=dep[f]+a[x],L[x]=++id,Id[id]=x,dis[x]=dis[f]+a[x];
    		for(int i=la[x]; i; i=pr[i])if(to[i]!=f)dfs(to[i],x);
    		R[x]=id;
    	}
    } A[2];
    int C[M][M];
    struct P2 {
    	ll dis[2][M],dp[M][M],G[M][M];
    	struct node {
    		ll y,x;
    	} stk[M],stk2[M][M],od[M][M],Od[M];
    	int L1[M],R1[M],L2[M][M],R2[M][M];//前面针对dp的,后面针对G的
    	ll up(node l,node r) {//l在r的右面
    		return r.y-l.y+l.x*l.x-r.x*r.x;
    	}
    	ll down(node l,node r) {
    		return l.x-r.x;
    	}
    	ll calc(ll x) {
    		return 1ll*x*x;
    	}
    	void Dfs(int x,int f,int rt) {
    		int L,R,p;
    		L1[x]=L1[f],R1[x]=R1[f];
    		int &l=L1[x],&r=R1[x];
    		if(x!=1) {
    			p=l,L=l+1,R=r;
    			while(L<=R) {
    				int mid=(L+R)>>1;
    				if(2.0*dis[1][f]*down(stk[mid],stk[mid-1])>=1.0*up(stk[mid],stk[mid-1]))p=mid,L=mid+1;
    				else R=mid-1;
    			}
    			l=p,dp[rt][x]=stk[l].y-calc(stk[l].x-dis[1][f])+C[rt][x];
    		}
    		L=l,R=r-1,p=r;
    		node now=(node)<%G[rt][x],dis[1][x]%>;
    		while(L<=R) {
    			int mid=(L+R)>>1;
    			if(1.0*up(stk[mid+1],stk[mid])*down(now,stk[mid+1])>=1.0*up(now,stk[mid+1])*down(stk[mid+1],stk[mid]))p=mid,R=mid-1;
    			else L=mid+1;
    		}
    		r=p,Od[x]=stk[r+1],stk[++r]=now;
    		for(int y,i=A[1].la[x]; i; i=A[1].pr[i])if((y=A[1].to[i])!=f)Dfs(y,x,rt);
    		if(x!=1)stk[R1[x]]=Od[x];
    	}
    	void dfs(int x,int f) {
    		if(x!=1) {
    			for(int y,i=1; i<=A[1].n; i++) {
    				y=A[1].Id[i],L2[x][y]=L2[f][y],R2[x][y]=R2[f][y];
    				int &l=L2[x][y],&r=R2[x][y],L=l+1,R=r,p=l;
    				while(L<=R) {
    					int mid=(L+R)>>1;
    					if(2.0*dis[0][f]*down(stk2[y][mid],stk2[y][mid-1])>=1.0*up(stk2[y][mid],stk2[y][mid-1]))p=mid,L=mid+1;
    					else R=mid-1;
    				}
    				l=p,G[x][y]=stk2[y][p].y-calc(dis[0][f]-stk2[y][p].x);
    			}
    			L1[0]=0,R1[0]=-1,Dfs(1,0,x);
    		}
    		for(int i=1; i<=A[1].n; i++) {
    			int y=A[1].Id[i],&l=L2[x][y],&r=R2[x][y],L=l,R=r-1,p=r;
    			node now=(node)<%dp[x][y],dis[0][x]%>;
    			while(L<=R) {
    				int mid=(L+R)>>1;
    				if(1.0*up(stk2[y][mid+1],stk2[y][mid])*down(now,stk2[y][mid+1])>=1.0*up(now,stk2[y][mid+1])*down(stk2[y][mid+1],stk2[y][mid]))p=mid,R=mid-1;
    				else L=mid+1;
    			}
    			r=p,od[x][y]=stk2[y][r+1],stk2[y][++r]=now;
    		}
    		for(int y,i=A[0].la[x]; i; i=A[0].pr[i])if((y=A[0].to[i])!=f)dfs(y,x);
    		for(int y,i=1; i<=A[1].n; i++)y=A[1].Id[i],stk2[y][R2[x][y]]=od[x][y];
    	}
    	void solve() {
    		A[0].dfs(1,0),A[1].dfs(1,0);
    		for(int i=1; i<=A[0].n; i++)dis[0][i]=A[0].dis[i];
    		for(int i=1; i<=A[1].n; i++)dis[1][i]=A[1].dis[i];
    		memset(dp,-63,sizeof(dp)),memset(G,-63,sizeof(G)),memset(R2,-1,sizeof(R2)),memset(R1,-1,sizeof(R1));
    		dp[1][1]=G[1][1]=0,dfs(1,0);
    		ll ans=-1e18;
    		for(int i=1; i<=A[0].n; i++)for(int j=1; j<=A[1].n; j++)ans=max(ans,dp[i][j]);
    		printf("%lld
    ",ans);
    	}
    } p2;
    int main() {
    	Rd(A[0].n),Rd(A[1].n);
    	for(int i=2; i<=A[0].n; i++)Rd(A[0].a[i]);
    	for(int i=2; i<=A[1].n; i++)Rd(A[1].a[i]);
    	for(int i=2; i<=A[0].n; i++)Rd(A[0].fa[i]),A[0].add(A[0].fa[i],i);
    	for(int i=2; i<=A[1].n; i++)Rd(A[1].fa[i]),A[1].add(A[1].fa[i],i);
    	for(int i=2; i<=A[0].n; i++)for(int j=2; j<=A[1].n; j++)Rd(C[i][j]);
    	p2.solve();
    	return 0;
    }
    
  • 相关阅读:
    网页素材收集
    【转】你离顶尖 Java 程序员,只差这11本书的距离
    Jetbrains 破解 2017
    WebStorm的常用操作
    浅谈MySQL主从复制
    Lombok注解指南
    【我的《冒号课堂》学习笔记】设计模式(3)行为模式
    【我的《冒号课堂》学习笔记】设计模式(2)结构模式
    【我的《冒号课堂》学习笔记】设计模式(1)创建模式
    【我的《冒号课堂》学习笔记】设计原则(4)保变原则
  • 原文地址:https://www.cnblogs.com/cly1231/p/12976526.html
Copyright © 2011-2022 走看看