zoukankan      html  css  js  c++  java
  • USACO2020JAN Platinum T2 官方正解的翻译?

    USACO2020JAN Platinum T2 官方正解的翻译?

    感觉这道题一时半会儿讲不清楚

    所以把文章翻译一遍,然后以注解的形式把讲的不太清楚的地方重新理解以下。

    (Benjamin Qi的分析)

    以下令(MOD=10^9+7)

    一些常数优化技巧:

    • (MOD)定义为常量。
    • 在加或减两个int时避免使用%运算。
    • 对于下文中提到的举证,使用固定大小的二维数组储存(而非c++中的vector套vector)。
    • 不必遍历矩阵中等于零的项(即后文中矩阵的下对角线部分)。

    定义一个关于模运算的独立的类(或struct)也会帮助卡常。

    为了方便,我们认为值域范围是([0,N))而非([1,N])

    同时后文会用到前文用到过的变量

    Subtask 1:

    我们可以在(O(NK^2))的复杂度内计算每一对((L,R)(1le Lle Rle N))的答案。

    具体而言,我们枚举左端点(L),然后先让(R=L)再不断的递增(R)

    (tot_i)表示以数字(i(iin[0,K) ))结尾的不降子序列有多少个。(将空子序列看做以零结尾)

    然后逐个往后添加数字更新(tot)数组。

    这样我们就可以(O(1))回答每个查询了。

    线段树 (subtasks 2,3):

    考虑到往区间([L,R])后添加数字(x)更新(tot)数组的过程其实是一个矩阵乘法(tot imes M_x)的过程。

    例如当(K=5)时向(tot)后添加一个(3)的过程相当于:

    [M3=egin{bmatrix}1&0&0&1&0\0&1&0&1&0\0&0&1&1&0\0&0&0&2&0\0&0&0&0&1end{bmatrix} ]

    满足

    [[c_0 c_1 c_2 c_3 c_4]cdot M_3=[c_0 c_1 c_2 c_0+c_1+c_2+2c_3 c_4] ]

    换句话说往后添加一个3相当于把(c_3)增加了(c_0+c_1+c_2+c_3)

    这启发我们用线段树来处理区间的矩阵乘法。

    若一个节点表示([L,R]),那它就应该表示矩阵(M=M_{A_L} imescdots imes M_{A_R})

    这样我们就可以(O(NK^3))建树,单次查询(O(K^3log N))的解决掉了。

    总复杂度是(O((N+Qlog N)K^3))的,大概率过不掉Subtask2。

    我们可以优化建树与查询的过程。

    • 对于查询,我们只需要知道矩阵乘积的第一行的结果。所以我们只需要处理(1 imes K)的矩阵和(K imes K)的矩阵相乘,这样单次查询复杂度就变成了(O(K^2log N)),这样就可以过掉subtask2了。

    • 对于建树,我们将每(K)个矩阵捆绑起来当成一个矩阵(类似分块?)然后对捆绑后的矩阵我们建一棵线段树。然后对于每次查询我们划分成捆绑后的矩阵乘积和零碎的数字。对于零碎的数字我们仿照Subtask1一样暴力(O(K))加入一个数。这样查询的复杂度也不会受影响。这样总复杂度就是(O(dfrac{N}{K}K^3+QK^2logN))

    优化后可以过subtask3,甚至可以拿到全部分数。

    分治(满分做法)

    线段树做法也可以修改,但是实在没有必要使用线段树。

    事实上,给定(b_1,b_2,ldots,b_N)和一个(O(1))的运算(oplus)(Q)次查询(igopluslimits_{i=L}^Rb_i),我们可以将询问离线下来,均摊(O(1))的回答。

    具体而言,我们每次把区间([L,R])划分为([L,M])([M+1,R]),预处理出所有([i,M])([M+1,i])区间的答案,然后对于跨过(M)的区间([l,r]),我们将([l,M])([M+1,r])合并即可。对于剩下不跨过(M)的区间,我们递归处理即可。

    在这道题中,我们可以处理出同时包含(M)(M+1)的区间。我们假设子序列的下标为(j_1<j_2<ldots<j_ale M<j_{a+1}<ldots<j_x)。我们枚举(A_{j_a})的值,生成所有合法的([i,M])([M+1,i])区间的答案。这样的复杂度是(O(NK^2))的。对于一个([l,r])的询问,我们用(O(K))的时间合并([l,M])([M+1,r])即可。剩下的询问递归下去处理即可。

    <注1>补充一下,具体生成所有合法的([i,M]​)([M+1,i]​)子序列个数是这样的:

    1. 对于左半区间([L,M]),(lans[i][j])表示在([i,M])中有多少个以数字(j)结尾的不降子序列。

      如官方正解中的代码段

      void countLeft(int a,int b)
      {
      	for(int i=a;i<=b;i++)
      		for(int k=1;k<=K;k++)
      			lans[i][k] = 0;
      	for(int k=K;k>=1;k--)
      	{
      		for(int j=k;j<=K;j++)
      			cnt[j] = 0;
      		for(int i=b;i>=a;i--)
      		{
      			if(A[i] == k)
      			{
      				cnt[k] = msum(2*cnt[k] + 1);
      				for(int j=k+1;j<=K;j++)
      					cnt[j] = msum(msum(2*cnt[j]) + lans[i][j]);
      			}
      			for(int j=k;j<=K;j++)
      				lans[i][j] = msum(lans[i][j] + cnt[j]);
      		}
      	}
      }
      

      我们先枚举(k),表示我们当前考虑以([k,K])结尾的不降子序列。然后我们按从(M)(L)的顺序枚举当前位置,(cnt[j])表示当前位置到(M)中 有多少个子序列 以(k)开头且以(j)结尾。显然(lans[i][j])就是枚举不同(k)时在位置(i)(cnt[j])的累加。

    2. 考虑右半区间时同理,(rans[i][j])表示([M+1,i])中有多少个以j开头的非降子序列,不过考虑到我们枚举的是(A_{j_a})的值,我们把(rans[i][j])做一个前缀和,让其表示([M+1,i])中有多少个以大于j的数开头的非降子序列。

    </注1>

    官方正解

    #include <iostream>
    #include <algorithm>
    #include <vector>
    using namespace std;
    #define MAXN 200000
    #define MAXQ 200000
    #define MOD 1000000007
    
    int msum(int a)
    {
      if(a >= MOD) return a-MOD;
      return a;
    }
    
    
    int N,K,Q;
    int A[MAXN];
    int l[MAXQ], r[MAXQ];
    int qid[MAXQ];
    int qans[MAXQ];
    
    int lans[MAXN][21];
    int rans[MAXN][21];
    int cnt[21];
    
    void countLeft(int a,int b)
    {
      for(int i=a;i<=b;i++)
      	for(int k=1;k<=K;k++)
      		lans[i][k] = 0;
      for(int k=K;k>=1;k--)
      {
      	for(int j=k;j<=K;j++)
      		cnt[j] = 0;
      	for(int i=b;i>=a;i--)
      	{
      		if(A[i] == k)
      		{
      			cnt[k] = msum(2*cnt[k] + 1);
      			for(int j=k+1;j<=K;j++)
      				cnt[j] = msum(msum(2*cnt[j]) + lans[i][j]);
      		}
      		for(int j=k;j<=K;j++)
      			lans[i][j] = msum(lans[i][j] + cnt[j]);
      	}
      }
    }
    
    void countRight(int a,int b)
    {
      for(int i=a;i<=b;i++)
      	for(int k=1;k<=K;k++)
      		rans[i][k] = 0;
      for(int k=1;k<=K;k++)
      {
      	for(int j=1;j<=k;j++)
      		cnt[j] = 0;
      	for(int i=a;i<=b;i++)
      	{
      		if(A[i] == k)
      		{
      			cnt[k] = msum(2*cnt[k] + 1);
      			for(int j=1;j<k;j++)
      				cnt[j] = msum(msum(2*cnt[j]) + rans[i][j]);
      		}
      		for(int j=1;j<=k;j++)
      			rans[i][j] = msum(rans[i][j] + cnt[j]);
      	}
      }
    }
    
    int split(int qa,int qb, int m)
    {
      int i = qa;
      int j = qb;
      while(i<j)
      {
      	if(r[qid[i]] > m && r[qid[j]] <= m)
      	{
      		swap(qid[i],qid[j]);
      		i++, j--;
      	}
      	else if(r[qid[i]] > m)
      		j--;
      	else if(r[qid[j]] <= m)
      		i++;
      	else
      		i++, j--;
      }
      if(i > j) return j;
      else if(r[qid[i]] <= m) return i;
      else return i-1;
    }
    
    void solve(int a,int b,int qa,int qb)
    {
      if(a>b || qa>qb) return;
      if(a == b)
      {
      	for(int i=qa;i<=qb;i++)
      		qans[qid[i]] = 1;
      	return;
      }
      int m = (a+b)/2;
      countLeft(a,m);
      countRight(m+1,b);
      for(int i=m+1;i<=b;i++)
      	for(int k=K-1;k>=1;k--)
      		rans[i][k] = msum(rans[i][k] + rans[i][k+1]);
      int qDone = 0;
      for(int i=qa;i<=qb;i++)
      {
      	int q = qid[i];
      	if(r[q] > m && l[q] <= m)
      	{
      		qans[q] = 0;
      		for(int k=1;k<=K;k++)
      			qans[q] = msum(qans[q] + (lans[l[q]][k]*((long long)rans[r[q]][k]))%MOD);
      		for(int k=1;k<=K;k++)
      			qans[q] = msum(qans[q] + lans[l[q]][k]);
      		qans[q] = msum(qans[q] + rans[r[q]][1]);
      		qDone++;
      	}
      	else if(qDone>0)
      		qid[i-qDone] = qid[i];
      }
      qb -= qDone;
      int qm = split(qa,qb,m);
      solve(a,m,qa,qm);
      solve(m+1,b,qm+1,qb);
    }
    
    int main()
    {
      freopen("nondec.in","r",stdin);
      freopen("nondec.out","w",stdout);
      cin >> N >> K;
      for(int i=0;i<N;i++)
      	cin >> A[i];
      cin >> Q;
      for(int i=0;i<Q;i++)
      {
      	cin >> l[i] >> r[i];
      	l[i]--,r[i]--;
      	qid[i] = i;
      }
      solve(0,N-1,0,Q-1);
      for(int i=0;i<Q;i++)
      	cout << qans[i]+1 << '
    ';
    }
    

    我的代码:

    #include<bits/stdc++.h>
    #define mp make_pair
    #define pb push_back
    #define fi first
    #define se second
    
    #define y0 pmt
    #define y1 pmtpmt
    #define x0 pmtQAQ
    #define x1 pmtQwQ
    
    using namespace std;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef vector<int > vi;
    typedef pair<int ,int > pii;
    const int INF=0x3f3f3f3f, maxn=200007;
    const int MOD=1e9+7;
    const ll LINF=0x3f3f3f3f3f3f3f3fLL;
    const ll P=19260817;
    char nc(){
        static char buf[100000],*p1=buf,*p2=buf;
        return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
    }
    inline ll read(){
        ll x=0,f=1;char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
        while(ch>='0'&&ch<='9')x=((x<<3)+(x<<1)+ch-'0')%MOD,ch=getchar();
        return x*f;
    }
    void write(int x){
        if(!x)putchar('0');if(x<0)x=-x,putchar('-');
        static int sta[20];register int tot=0;
        while(x)sta[tot++]=x%10,x/=10;
        while(tot)putchar(sta[--tot]+48);
    }
    int n,K,Q;
    struct Query{
        int l,r,id;
    }q[maxn],q1[maxn],q2[maxn];
    // int len1,len2;
    int a[maxn];
    int cnt[27];
    int lans[maxn][27],rans[maxn][27];
    int ans[maxn];
    int mod1(int a){return a>MOD?a-MOD:a;}
    void left(int l,int r){
        for(int i=l;i<=r;i++)
            for(int j=1;j<=K;j++)lans[i][j]=0;
        for(int k=K;k>=1;k--){
            for(int j=k;j<=K;j++)cnt[j]=0;
            for(int i=r;i>=l;i--){
                if(a[i]==k){
                    cnt[k]=mod1(mod1(cnt[k]*2)+1);
                    for(int j=k+1;j<=K;j++)cnt[j]=mod1(mod1(cnt[j]*2)+lans[i][j]);
                }
                for(int j=k;j<=K;j++)lans[i][j]=mod1(lans[i][j]+cnt[j]);
            }
        }
    }
    void right(int l,int r){
        for(int i=l;i<=r;i++)
            for(int j=1;j<=K;j++)rans[i][j]=0;
        for(int k=1;k<=K;k++){
            for(int j=1;j<=k;j++)cnt[j]=0;
            for(int i=l;i<=r;i++){
                if(a[i]==k){
                    cnt[k]=mod1(mod1(cnt[k]<<1)+1);
                    for(int j=1;j<k;j++)cnt[j]=mod1(mod1(cnt[j]<<1)+rans[i][j]);
                }
                for(int j=1;j<=k;j++)rans[i][j]=mod1(rans[i][j]+cnt[j]);
            }
        }
        for(int i=l;i<=r;i++)
            for(int k=K-1;k>=1;k--)
                rans[i][k]=mod1(rans[i][k]+rans[i][k+1]);
    }
    void solve(int l,int r,int ql,int qr){
        if(l>r||ql>qr)return ;
        if(l==r){
            for(int i=ql;i<=qr;i++){
                ans[q[i].id]=1;
            }
            return ;
        }
        int mid=(l+r)>>1;
        left(l,mid);right(mid+1,r);
        int len1=0,len2=0;
        for(int i=ql;i<=qr;i++){
            if(q[i].r<=mid)q1[++len1]=q[i];
            else if(q[i].l>mid)q2[++len2]=q[i];
            else {
                for(int k=1;k<=K;k++)ans[q[i].id]=mod1(ans[q[i].id]+1ll*lans[q[i].l][k]*rans[q[i].r][k]%MOD);
                for(int k=1;k<=K;k++)ans[q[i].id]=mod1(ans[q[i].id]+lans[q[i].l][k]);
                ans[q[i].id]=mod1(ans[q[i].id]+rans[q[i].r][1]);
            }
        }
        for(int i=ql;i<ql+len1;i++)q[i]=q1[i-ql+1];
        for(int i=ql+len1;i<ql+len1+len2;i++)q[i]=q2[i-ql-len1+1];
        solve(l,mid,ql,ql+len1-1);
        solve(mid+1,r,ql+len1,ql+len1+len2-1);
    
    }
    int main(){
        scanf("%d%d",&n,&K);
        for(int i=1;i<=n;i++)scanf("%d",a+i);
        scanf("%d",&Q);
        for(int i=1;i<=Q;i++)
            scanf("%d%d",&q[i].l,&q[i].r),q[i].id=i;
        solve(1,n,1,Q);
        for(int i=1;i<=Q;i++)printf("%d
    ",mod1(ans[i]+1));
    	return 0;
    }
    
    

    复杂度是(O(NK^2log N+Q(K+log N)))

    矩阵逆(满分做法)

    (ipref[x]=M_{A_{x-1}}^{-1}cdot M_{A_{x-2}}^{-1}cdots M_{A_1}^{-1}),以及(pref[x]=M_{A_1}cdot M_{A_2}cdots M_{A_{x-1}}),计算([M_x^{-1}])其实相当容易。

    例如当(K=5)时,

    [M_3^{-1}=egin{bmatrix} 1 & 0 & 0 & -1/2 & 0 \ 0 & 1 & 0 & -1/2 & 0 \ 0 & 0 & 1 & -1/2 & 0 \ 0 & 0 & 0 & 1/2 & 0 \ 0 & 0 & 0 & 0 & 1 \ end{bmatrix}, ]

    满足

    [egin{bmatrix} c_0 & c_1 & c_2 & c_0+c_1+c_2+2c_3 & c_4 end{bmatrix}cdot M_3^{-1}= egin{bmatrix} c_0 & c_1 & c_2 & c_3 & c_4 end{bmatrix}. ]

    从而又因为(M_x)以及(M_x^{-1})和单位矩阵很像,我们可以用乘法分配率在(O(NK^2))的时间内处理出(pref,ipref)

    那么([L,R])的答案(A_L,ldots,A_R)的不降子序列个数其实就是(sumlimits_{i=0}^{K-1}(ipref[L-1]cdot pref[R])[0][i]),单次复杂度是(O(K^2))的。

    考虑到我们只需要乘积矩阵的第一行,答案可以改写为(sumlimits_{i=0}^{K-1}(ipref[L-1][0]cdot pref[R])[i]\ =sum_{i=0}^{K-1}(ipref[L-1][0][i]cdot left(sum_{j=0}^{K-1}pref[R][i][j] ight)).)

    ((sum_{j=0}^{K-1}pref[R][i][j]))显然可以预处理,然后就没了。

    复杂度是(O(NK^2+QK))的。

    官方sol:

    #include <bits/stdc++.h>
    using namespace std;
     
    typedef long long ll;
    const int MOD = 1e9+7; // 998244353; // = (119<<23)+1
    const int MX = 5e4+5; 
    
    void setIO(string name) {
    	ios_base::sync_with_stdio(0); cin.tie(0);
    	freopen((name+".in").c_str(),"r",stdin);
    	freopen((name+".out").c_str(),"w",stdout);
    }
     
    struct mi {
    	int v; explicit operator int() const { return v; }
    	mi(ll _v) : v(_v%MOD) { v += (v<0)*MOD; }
    	mi() : mi(0) {}
    };
    mi operator+(mi a, mi b) { return mi(a.v+b.v); }
    mi operator-(mi a, mi b) { return mi(a.v-b.v); }
    mi operator*(mi a, mi b) { return mi((ll)a.v*b.v); }
    typedef array<array<mi,20>,20> T;
     
    int N,K,Q;
    vector<int> A;
    array<mi,20> sto[MX], isto[MX];
    mi i2 = (MOD+1)/2;
     
    void prin(T& t) { // print a matrix for debug purposes
    	for (int i = 0; i < K; ++i) {
    		for (int j = 0; j < K; ++j) 
    			cout << t[i][j].v << ' ';
    		cout << "
    ";
    	}
    	cout << "-------
    ";
    }
     
    int main() {
    	setIO("nondec");
    	cin >> N >> K; A.resize(N); 
    	for (int i = 0; i < N; ++i) cin >> A[i];
    	T STO, ISTO;
    	for (int i = 0; i < K; ++i) 
    		STO[i][i] = ISTO[i][i] = 1;
    	for (int i = 0; i <= N; ++i) {
    		for (int j = 0; j < K; ++j) 
    			for (int k = j; k < K; ++k) 
    				sto[i][j] = sto[i][j]+STO[j][k];
    		for (int k = 0; k < K; ++k) 
    			isto[i][k] = ISTO[0][k];
    		if (i == N) break;
    		int x = A[i]-1;
    		// STO goes from pre[i] to pre[i+1]
    		// set STO = STO*M_{A[i]}
    		for (int j = 0; j <= x; ++j) 
    			for (int k = x; k >= j; --k) 
    				STO[j][x] = STO[j][x]+STO[j][k];
    		// ISTO goes from ipre[i] to ipre[i+1]
    		// set ISTO=M_{A[i]}^{-1}*ISTO
    		for (int j = 0; j < x; ++j) 
    			for (int k = x; k < K; ++k)
    				ISTO[j][k] = ISTO[j][k]-i2*ISTO[x][k];
    		for (int k = x; k < K; ++k) 
    			ISTO[x][k] = ISTO[x][k]*i2;
    	}
    	cin >> Q;
    	for (int i = 0; i < Q; ++i) {
    		int L,R; cin >> L >> R;
    		mi ans = 0; 
    		for (int j = 0; j < K; ++j) 
    			ans = ans+isto[L-1][j]*sto[R][j];
    		cout << ans.v << "
    ";
    	}
    }
    

    我就不写了/kk;

  • 相关阅读:
    Qt状态机实例
    <STL> accumulate 与 自定义数据类型
    <STL> 容器混合使用
    散列表(C版)
    Canonical 要将 Qt 应用带入 Ubuntu
    <STL> set随笔
    C++ 文件流
    视频播放的基本原理
    <STL> pair随笔
    c++ 内存存储 解决char*p, char p[]的问题
  • 原文地址:https://www.cnblogs.com/pmt2018/p/12239401.html
Copyright © 2011-2022 走看看