四维DP了解一下???
好了首先一般不会先想到四维DP,一般都是想到二维DP了,所以我们先讲一个二维的dfs做法(尽管只有30pts)
T成这样子:
#include<bits/stdc++.h> #define ll long long using namespace std; inline ll read(){ ll ans=0; char last=' ',ch=getchar(); while(ch>'9'||ch<'0') last=ch,ch=getchar(); while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar(); if(last=='-') ans=-ans; return ans; } int n,m,maxn; int a[351],b[121],cnt[5]; int f[351][5]; void dfs(int num,int ans){ if(num==n){ if(maxn<ans) maxn=ans; return; } for(int i=1;i<=4;i++){ if(!cnt[i]) continue; cnt[i]--; dfs(num+i,ans+a[num+i]); cnt[i]++; } } int main(){ n=read();m=read(); for(int i=1;i<=n;i++) a[i]=read(); for(int i=1;i<=m;i++) b[i]=read(),cnt[b[i]]++; dfs(1,a[1]); cout<<maxn<<endl; return 0; }
思路的话不是很难理解,num表示当前到达的棋盘位置(一开始显然是1),ans表示当前分数的最大值,然后当num==n的时候,与当前记录的最大值比较,然后留下较大的那一个;
否则的话,选择1~4任一卡片(前提是有,用cnt数组维护)进行递归操作,最后莫得忘记回溯;
然后本想就着这份代码,改一改改成记忆化的,但是你会发现你无从下手,因为限制条件太多啦。
所以小手一伸,我们去借鉴题解
正解的记忆化是四维的,四个参数abcd分别表示走1,2,3,4步的牌剩余的数量;
然后开数组f[i][j][k][l]表示用i张1,j张2,k张3,l张4的最大的得分,然后如同DP一样,首先判断还有没有这个牌(即if(*)*代表任意一种牌),如果有,我们令这种牌对应的参数减一进行dfs,比较当前值与dfs值,取最大的;
#include<bits/stdc++.h> #define ll long long using namespace std; inline ll read(){ ll ans=0; char last=' ',ch=getchar(); while(ch>'9'||ch<'0') last=ch,ch=getchar(); while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar(); if(last=='-') ans=-ans; return ans; } int n,m; int s[351],x[121],cnt[5]; int f[41][41][41][41]; int dfs(int a,int b,int c,int d){ if(f[a][b][c][d]!=0) return f[a][b][c][d]; if(a)f[a][b][c][d]=max(f[a][b][c][d],dfs(a-1,b,c,d)); if(b)f[a][b][c][d]=max(f[a][b][c][d],dfs(a,b-1,c,d)); if(c)f[a][b][c][d]=max(f[a][b][c][d],dfs(a,b,c-1,d)); if(d)f[a][b][c][d]=max(f[a][b][c][d],dfs(a,b,c,d-1)); f[a][b][c][d]+=s[a+2*b+3*c+4*d+1]; return f[a][b][c][d]; } int main(){ n=read();m=read(); for(int i=1;i<=n;i++) s[i]=read(); for(int i=1;i<=m;i++) x[i]=read(),cnt[x[i]]++; cout<<dfs(cnt[1],cnt[2],cnt[3],cnt[4])<<endl; return 0; }
关于DP正解:
#include<bits/stdc++.h> using namespace std; inline int read(){ int ans=0; char last=' ',ch=getchar(); while(ch<'0'||ch>'9') last=ch,ch=getchar(); while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar(); if(last=='-') ans=-ans; return ans; } int n,m,a[351],b[121],cnt[5]; int f[41][41][41][41]; int main(){ n=read();m=read(); for(int i=1;i<=n;i++) a[i]=read(); for(int i=1;i<=m;i++) b[i]=read(),cnt[b[i]]++; f[0][0][0][0]=a[1]; for(int i=0;i<=cnt[1];i++) for(int j=0;j<=cnt[2];j++) for(int k=0;k<=cnt[3];k++) for(int l=0;l<=cnt[4];l++){ int now=1+i+j*2+k*3+l*4; if(i) f[i][j][k][l]=max(f[i][j][k][l],f[i-1][j][k][l]+a[now]); if(j) f[i][j][k][l]=max(f[i][j][k][l],f[i][j-1][k][l]+a[now]); if(k) f[i][j][k][l]=max(f[i][j][k][l],f[i][j][k-1][l]+a[now]); if(l) f[i][j][k][l]=max(f[i][j][k][l],f[i][j][k][l-1]+a[now]); } cout<<f[cnt[1]][cnt[2]][cnt[3]][cnt[4]]<<endl; return 0; }
凑活着看看得了;
end-