题面
(T) 组测试数据。有 (n) 个音节,每个音节 (h_iin[1,A]),还有 (m) 个限制 ((l_i,r_i,g_i)) 表示 (max_{k=l_i}^{r_i}h_k=g_i)。求满足条件的 (h_i) 的方案数膜 (998244353)。
数据范围:(1le Tle 20),(1le l_ile r_ile nle 9cdot 10^8),(1le g_ile Ale 9cdot 10^8),(1le mle 500)。
蒟蒻语
这是一眼题,蒟蒻 (7) 个小时就秒掉了。
题解做法时间复杂度 (Theta(Tmlog m))。
蒟蒻解
转化条件:(forall kin[l_i,r_i],h_kle g_i) 并且 (exists kin [l_i,r_i],h_kge g_i)。
可以离散化,所以下文中可以把 (n) 看成与 (m) 相等的范围。
很明显对于 (g_i<g_j) 如果一个区间满足了 (i) 条件就无法满足 (j) 条件了,所以可以按 (g) 从小到大排序条件,每次处理完以后把区间删除;但是又会发现对于 (g) 相等的区间情况很复杂,所以可以排序后对每个 (g) 分开处理覆盖部分答案求积。删除区间可以用并查集(参考),然后最后对于每个没删除的音节答案乘以 (A) 即可。
然后只需假设 (g) 一定,并且区间覆盖全部音节。这时有多个棘手的条件,根据条件的形式容易想到子集反演。
设 (p=(forall kin[l_i,r_i],h_kle g_i)),(q=(exists kin [l_i,r_i],h_kge g_i)),所以:
这个子集反演是可以用 ( m dp) 维护的,(f(i,j)) 表示 (forall kin[i,j),h_k<g_i) 前 (i) 个音节的方案数。
假设有个以 (i) 为左端点的条件 ([i,t)),可以用 ( m dp) 维护反演老套路打个负标记并用类似 ( m min-max) 卷积的方法:
操作比看起来简单:把 (jge t) 的 (f(i,j)) 直接置为 (0),然后 (f(i,t)leftarrowsum_{j<t} f(i,j)=sum f(i,j))。
然后是一波骚操作:把 (i) 这维滚掉(用离散化和并查集),然后线段树维护 (j) 这一维(啊!( m dp) 沦陷了!)。
最后还有个小细节:要把 (g) 相等但是 ([l_i,r_i]in[l_j,r_j]) 的 (j) 条件不放到 ( m dp) 里,最后删除区间时统计。
时间复杂度 (Theta(Tmlog m))。 蒟蒻的题解都写成这样了,有人看得懂就怪了。
代码
#include <bits/stdc++.h>
using namespace std;
//Start
typedef long long ll;
typedef double db;
#define mp(a,b) make_pair((a),(b))
#define x first
#define y second
#define be(a) (a).begin()
#define en(a) (a).end()
#define sz(a) int((a).size())
#define pb(a) push_back(a)
#define R(i,a,b) for(int i=(a),I=(b);i<I;i++)
#define L(i,a,b) for(int i=(b)-1,I=(a)-1;i>I;i--)
const int iinf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
//Data
const int Q=501,N=Q<<1,mod=998244353;
int Pow(int a,int x){
if(!a) return 0; int res=1;
for(;x;a=1ll*a*a%mod,x>>=1)if(x&1) res=1ll*a*res%mod;
return res;
}
int n,m,h,l[Q],r[Q],g[Q],dn,d[N],te[N+1];
int find(int u){return u==te[u]?u:(te[u]=find(te[u]));}
map<int,vector<int>> mp;
//SegmentTree
const int T=N<<2; int val[T],ml[T];
#define mid ((l+r)>>1)
void clear(int k=0,int l=0,int r=dn){
val[k]=0,ml[k]=1; if(r-l==1) return;
clear(k*2+1,l,mid),clear(k*2+2,mid,r);
}
void pushup(int k){(val[k]=val[k*2+1]+val[k*2+2])%=mod;}
void pushmul(int k,int v){val[k]=1ll*val[k]*v%mod,ml[k]=1ll*ml[k]*v%mod;}
void pushdown(int k){if(ml[k]^1) pushmul(k*2+1,ml[k]),pushmul(k*2+2,ml[k]),ml[k]=1;}
void add(int x,int v,int k=0,int l=0,int r=dn){
if(r<=x||x+1<=l) return; if(r-l==1) return void((val[k]+=v)%=mod);
pushdown(k),add(x,v,k*2+1,l,mid),add(x,v,k*2+2,mid,r),pushup(k);
}
void mul(int x,int y,int v,int k=0,int l=0,int r=dn){
if(r<=x||y<=l) return; if(x<=l&&r<=y) return pushmul(k,v);
pushdown(k),mul(x,y,v,k*2+1,l,mid),mul(x,y,v,k*2+2,mid,r),pushup(k);
}
//DP
int dp(int x,vector<int> a){
mul(0,dn,0),add(0,1); vector<int> b;
for(int i:a) l[i]=find(l[i]),r[i]=find(r[i]);
sort(be(a),en(a),[&](int p,int q){return l[p]!=l[q]?l[p]<l[q]:r[p]<r[q];});
for(int i:a){while(sz(b)&&r[b.back()]>=r[i]) b.pop_back();b.pb(i);}
R(t,0,sz(b)){
int i=find(l[b[t]]),j=find(r[b[t]]); if(i>=j) return 0;
int ni=j;if(t+1<sz(b)) ni=min(ni,find(l[b[t+1]]));
mul(j,dn,0),add(j,(mod-val[0])%mod);
for(int k=i;k<ni;k=te[k]=find(k+1))
mul(0,k+1,Pow(x,d[k+1]-d[k])),mul(k+1,dn,Pow(x-1,d[k+1]-d[k]));
}
int res=val[0];
for(int i:a) for(int k=find(l[i]);k<find(r[i]);k=te[k]=find(k+1)) res=1ll*res*Pow(x,d[k+1]-d[k])%mod;
return res;
}
//Main
int Main(){
cin>>n>>m>>h,dn=0,d[dn++]=0,d[dn++]=n,mp.clear();
R(i,0,m) cin>>l[i]>>r[i]>>g[i],--l[i],d[dn++]=l[i],d[dn++]=r[i];
sort(d,d+dn),dn=unique(d,d+dn)-d,clear();
R(i,0,m) l[i]=lower_bound(d,d+dn,l[i])-d,r[i]=lower_bound(d,d+dn,r[i])-d,mp[g[i]].pb(i);
iota(te,te+dn+1,0); int res=1;
for(auto u:mp) res=1ll*res*dp(u.x,u.y)%mod;
for(int k=find(0);k<dn-1;k=te[k]=find(k+1)) res=1ll*res*Pow(h,d[k+1]-d[k])%mod;
return res;
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0),cout.tie(0);
int ta; cin>>ta; while(ta--) cout<<Main()<<'
';
return 0;
}
祝大家学习愉快!