【BZOJ2905】背单词
Description
给定一张包含N个单词的表,每个单词有个价值W。要求从中选出一个子序列使得其中的每个单词是后一个单词的子串,最大化子序列中W的和。
Input
第一行一个整数TEST,表示数据组数。
接下来TEST组数据,每组数据第一行为一个整数N。
接下来N行,每行为一个字符串和一个整数W。
Output
TEST行,每行一个整数,表示W的和的最大值。
数据规模
设字符串的总长度为Len
30.的数据满足,TEST≤5,N≤500,Len≤10^4
100.的数据满足,TEST≤10,N≤20000,Len≤3*10^5
题解:先建出AC自动机,然后A串是B串的子串当且仅当B中某个节点沿着fail树往根走,能走到A的结束节点。那么我们先将权值<=0的串都扔掉,然后从前往后枚举每个字符串,对于每个串,我们查询它的每个节点到根路径上的所有节点的DP值的最大值,然后用最大值+当前串的价值得到当前点的DP值,最后将当前串的DP值存到当前串的结束节点位置上。
查询最大值的时候可以将链查询,点修改变成点查询,子树修改,然后用线段树维护即可。
#include <cstdio> #include <cstring> #include <iostream> #include <queue> #include <vector> #define lson x<<1 #define rson x<<1|1 using namespace std; const int maxn=20010; const int maxm=300010; typedef long long ll; struct node { int ch[26],fail; }p[maxm]; queue<int> q; int T,n,tot,cnt; ll ans,sum; int v[maxn],to[maxm],next[maxm],head[maxm],p1[maxm],p2[maxm]; char str[maxm]; vector<int> pos[maxn]; ll s[maxm<<2],tag[maxm<<2]; inline void build() { register int i,u; q.push(1); while(!q.empty()) { u=q.front(),q.pop(); for(i=0;i<26;i++) { if(!p[u].ch[i]) { if(u==1) p[u].ch[i]=1; else p[u].ch[i]=p[p[u].fail].ch[i]; continue; } q.push(p[u].ch[i]); if(u==1) p[p[u].ch[i]].fail=1; else p[p[u].ch[i]].fail=p[p[u].fail].ch[i]; } } } inline void add(int a,int b) { to[cnt]=b,next[cnt]=head[a],head[a]=cnt++; } void dfs(int x) { p1[x]=++p2[0]; for(int i=head[x];i!=-1;i=next[i]) dfs(to[i]); p2[x]=p2[0]; } inline void pushdown(int x) { if(tag[x]) { s[lson]=max(s[lson],tag[x]),s[rson]=max(s[rson],tag[x]); tag[lson]=max(tag[lson],tag[x]),tag[rson]=max(tag[rson],tag[x]); tag[x]=0; } } void updata(int l,int r,int x,int a,int b,ll c) { if(a<=l&&r<=b) { s[x]=max(s[x],c),tag[x]=max(tag[x],c); return ; } pushdown(x); int mid=(l+r)>>1; if(a<=mid) updata(l,mid,lson,a,b,c); if(b>mid) updata(mid+1,r,rson,a,b,c); s[x]=max(s[lson],s[rson]); } ll query(int l,int r,int x,int a) { if(l==r) return s[x]; pushdown(x); int mid=(l+r)>>1; if(a<=mid) return query(l,mid,lson,a); return query(mid+1,r,rson,a); } inline void work() { register int i,j,a,b,u; memset(s,0,sizeof(s[0])*4*(tot+1)),memset(tag,0,sizeof(tag[0])*4*(tot+1)),memset(p,0,sizeof(p[0])*(tot+1)); scanf("%d",&n); tot=1,cnt=p2[0]=0,ans=0; for(i=1;i<=n;i++) { scanf("%s%d",str,&v[i]),a=strlen(str); if(v[i]<=0) continue; pos[i].clear(); for(u=1,j=0;j<a;j++) { b=str[j]-'a'; if(!p[u].ch[b]) p[u].ch[b]=++tot; u=p[u].ch[b],pos[i].push_back(u); } } build(); memset(head,-1,sizeof(head[0])*(tot+1)); for(i=2;i<=tot;i++) add(p[i].fail,i); dfs(1); for(i=1;i<=n;i++) if(v[i]>0) { for(a=pos[i].size(),sum=0,j=0;j<a;j++) sum=max(sum,query(1,tot,1,p1[pos[i][j]])); sum+=v[i],ans=max(ans,sum); updata(1,tot,1,p1[pos[i][a-1]],p2[pos[i][a-1]],sum); } printf("%lld ",ans); } int main() { //freopen("data.in","r",stdin); //freopen("data.out","w",stdout); scanf("%d",&T); while(T--) work(); return 0; }//1 5 a 1 ab 1 ac 4 abc 2 aa 1