NOI.AC #111. 运气大战 (动态dp)
对于两个权值排序之后,可以主观意会每次选的数对保证\(|i-j|\leq 2\)
网上这个都有很多,我写这篇题解只是想吐槽一下他们毫无可读性的代码以及令人完全感受不到正确性的转移。。。
对于线段树上的每个节点\([l,r]\)存储一个矩阵\(a[2][2]\),表示已经解决了\([l-i,r-j]\)这段区间的数对
每次合并儿子的时候,是\([l,mid],[mid+1,r]\)两个区间的合并,即合并了\([l-i,mid-j],[mid+1-j,r-k]\),区间是相邻重合的,枚举\(i,j,k\)之后,就是矩阵乘法
预处理阶段,每次的区间是\([x-i,x-j]\),长度不超过3,直接暴力枚举
可以看到每次更新都会影响到\([x,x+2]\)这段区间
特殊的是,当前点只有一个点时,我们需要考虑取的是空集的边界条件
#include<bits/stdc++.h>
using namespace std;
#define reg register
typedef long long ll;
#define rep(i,a,b) for(int i=a,i##end=b;i<=i##end;++i)
#define drep(i,a,b) for(int i=a,i##end=b;i>=i##end;--i)
#define pb push_back
template <class T> inline void cmin(T &a,T b){ ((a>b)&&(a=b)); }
template <class T> inline void cmax(T &a,T b){ ((a<b)&&(a=b)); }
char IO;
template<class T=int> T rd(){
T s=0;
int f=0;
while(!isdigit(IO=getchar())) if(IO=='-') f=1;
do s=(s<<1)+(s<<3)+(IO^'0');
while(isdigit(IO=getchar()));
return f?-s:s;
}
const int N=3e4+10;
const ll INF=1e18;
int n,m;
int a[N],b[N],pa[N],pb[N],pos[N];
int p[N];
struct Mat{
ll a[3][3]; // [l-i,r-j]
void clear(){ memset(a,-63,sizeof a); }
Mat operator * (const Mat x) {
Mat res; res.clear();
rep(i,0,2) rep(j,0,2) rep(k,0,2) cmax(res.a[i][k],a[i][j]+x.a[j][k]); // 朴素的矩阵转移
return res;
}
#define F(x,y) (pa[x]==pb[y]?-INF:1ll*::a[pa[x]]*b[pb[y]]) // 注意有非法的情况要赋为-INF
inline ll Calc(int l,int r){ // 对于区间l,r暴力预处理,r-l+1<=3
if(l==r) return F(l,r);
if(r-l==1) return max(F(l,l)+F(r,r),F(l,r)+F(r,l));
ll res=-1e18,T[3][3];
rep(i,l,r) rep(j,l,r) T[i-l][j-l]=F(i,j);
if(T[0][0]>=0) cmax(res,max(T[0][0]+T[1][1]+T[2][2],T[0][0]+T[1][2]+T[2][1]));
if(T[0][1]>=0) cmax(res,max(T[0][1]+T[1][0]+T[2][2],T[0][1]+T[1][2]+T[2][0]));
if(T[0][2]>=0) cmax(res,max(T[0][2]+T[1][0]+T[2][1],T[0][2]+T[1][1]+T[2][0]));
// 手动枚举6种情况
return res;
}
void Init(int x){ // 单点x的初始状态
clear();
rep(i,0,min(2,x-1)) {
rep(j,0,i) a[i][j]=Calc(x-i,x-j); // 枚举[l-i,r-j]
if(i<2) a[i][i+1]=0;
}
}
}s[N<<2];
void Build(int p,int l,int r) {
if(l==r) return s[p].Init(l);
int mid=(l+r)>>1;
Build(p<<1,l,mid),Build(p<<1|1,mid+1,r);
s[p]=s[p<<1]*s[p<<1|1];
}
void Upd(int p,int l,int r,int ql,int qr){
if(l==r) return s[p].Init(l);
int mid=(l+r)>>1;
if(ql<=mid) Upd(p<<1,l,mid,ql,qr);
if(qr>mid) Upd(p<<1|1,mid+1,r,ql,qr);
s[p]=s[p<<1]*s[p<<1|1];
}
int main(){
n=rd(),m=rd();
rep(i,1,n) a[i]=rd(),pa[i]=i;
rep(i,1,n) b[i]=rd(),pb[i]=i;
sort(pa+1,pa+n+1,[&](int x,int y){ return a[x]<a[y]; });
sort(pb+1,pb+n+1,[&](int x,int y){ return b[x]<b[y]; });
rep(i,1,n) pos[pb[i]]=i; // 预处理排序之后的序列
Build(1,1,n);
rep(i,1,m){
int x=rd(),y=rd();
swap(b[x],b[y]);
swap(pos[x],pos[y]);
x=pos[x],y=pos[y];
swap(pb[x],pb[y]);
Upd(1,1,n,x,min(x+2,n));
Upd(1,1,n,y,min(y+2,n)); // 注意每次更新3个
printf("%lld\n",s[1].a[0][0]);
}
}