题面
给定 (n) 个字符串,从中选出若干个按给出顺序连接起来,总长等于 (m),求字典序最小的,保证有解。
数据范围:(1le nle 2000),(1le kle 10^4),字符串总长 (Sle 10^6)。
题解
atcoder
的题就是好啊,非常巧妙,毫不毒瘤。
这篇题解要抨击 Z-function
怪 ycx
,造福人民。
有一个非常显然的思路:每次选当前选了的串右边没选且能选的中最小的。
这里的能选的可以用个后缀背包判断,可以用 bitset
维护。
注意到对于每个长度 (i) 能不能由后缀拼成有个分界点,设为 (r_i)。
这个做法有个明显的问题:如果有一个串 geo
和一个串 geoafo
,怎么知道哪个好呢?
如果选了
geo
,如果剩下的串只有boomzero
,那么就不如geoafo
了。如果选了
geoafo
,如果剩下的串有个aba
,那么就geoaba
更优了。
所以想到用 dp
,来避免不同长度的字符串的比较。
设 (f_i) 表示拼成的长度为 (i) 且剩下字符串可以拼成 (m-i) 的字符串中字典序最小的(注意 (f_i) 是个字符串!)。
然后 (g_i) 是拼成 (f_i) 的最右字符串的编号,要最左。
转移就是从小到大枚举每个 (i)(前提是有这样的 (f_i)),然后枚举加上的字符串长度 (j)。
要找到 (g_i) 到 (r_{m-i-j}) 中长度为 (j) 的字典序最小字符串,这个可以用 st
表预处理 (Theta(1)) 完成。
然后把这个串加到 (f_i) 后面暴力和 (f_{i+j}) 比较即可。
关于时间复杂度,总感觉是 (Theta(n^2)) 的,但是不会证,就当 (Theta(n^3)) 的吧(可是它明明跑得很快 /kk
)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef double db;
#define x first
#define y second
#define bg begin()
#define ed end()
#define pb push_back
#define mp make_pair
#define sz(a) int((a).size())
#define R(i,n) for(int i(0);i<(n);++i)
#define L(i,n) for(int i((n)-1);i>=0;--i)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int N=2000+1,M=1e4+1,D=12;
int n,m,o[N],we[N],ri[M],g[M];
vector<int> id[M],st[M][D];
string s[N],f[M];
bitset<M> h[N];
//Functions
int get(int j,int l,int r){
l=lower_bound(id[j].bg,id[j].ed,l)-id[j].bg;
r=lower_bound(id[j].bg,id[j].ed,r)-id[j].bg;
if(r-l<=0) return -1;
int t=log2(r-l),&a=st[j][t][l],&b=st[j][t][r-(1<<t)];
return we[a]>we[b]?b:a;
}
//Main
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
cin>>n>>m,h[n].set(0);
R(i,n) cin>>s[o[i]=i];
stable_sort(o,o+n,[&](int i,int j){return s[i]<s[j];});
R(i,n) we[o[i]]=i,id[sz(s[i])].pb(i);
R(i,m+1){
g[i]=-1,ri[i]=0;
vector<int> &c=id[i],*a=st[i];
if(!sz(c)) continue;
R(i,D) a[i].resize(sz(c));
R(i,sz(c)) a[0][i]=c[i];
R(t,D-1)R(i,sz(c)+1-(2<<t)){
int &x=a[t][i],&y=a[t][i+(1<<t)];
a[t+1][i]=we[x]>we[y]?y:x;
}
}
L(i,n){
h[i]=h[i+1]|h[i+1]<<sz(s[i]);
R(j,m+1)if(h[i][j]&&!h[i+1][j]) ri[j]=i;
}
f[0]="",g[0]=0,ri[0]=n;
R(j,m)if(~g[j])L(i,m+1-j){
int t=get(i,g[j],ri[m-j-i]);
// cout<<"j="<<j<<"->j+i="<<j+i<<'
';
// cout<<"get("<<i<<","<<g[j]<<","<<ri[m-j-i]<<")="<<t<<'
';
if(~t&&(!~g[j+i]||f[j+i]>f[j]+s[t]))
f[j+i]=f[j]+s[t],g[j+i]=t+1;
}
// R(i,m+1) cout<<g[i]<<" "<<f[i]<<'
';
cout<<f[m]<<'
';
return 0;
}
祝大家学习愉快!