zoukankan      html  css  js  c++  java
  • 【BZOJ4944】【NOI2017】泳池 概率DP 常系数线性递推 特征多项式 多项式取模

    题目大意

      有一个(1001 imes n)的的网格,每个格子有(q)的概率是安全的,(1-q)的概率是危险的。

      定义一个矩形是合法的当且仅当:

    • 这个矩形中每个格子都是安全的
    • 必须紧贴网格的下边界

      问你最大的合法子矩形大小为(k)的概率是多少。

      (nleq {10}^9,kleq 1000)

      吉老师:这题本来是(kleq 20000)

    题解

      一道好题。

      我们计算最大子矩形不超过(i)的答案(s_i),那么答案就是(s_k-s_{k-1})

      显然最后一行连续的安全格子不会超过(k)个。

      设(g_{i,j})表示长度为(j),高度为(i)的海域全部是安全的,剩下的部分未知,最大子矩形(leq k)的概率。

      设(h_{i,j})表示长度为(j),高度为(i+1)的海域中,前(i)行全部是安全的,剩下的未知且((i+1,j))是危险的,最大子矩形(leq k)的概率。

      边界:

    [egin{align} g_{k,1}&=q^k(1-q)\ g_{i,0}&=1\ h_{i,0}&=1 end{align} ]

      那么我们从(k-1)(1)DP,对于(i)(j)列,枚举第(i+1)行的下一个危险的格子在哪个地方,然后转移:

    [egin{align} g_{i,j}&=sum_{k=0}^{j}h_{i,k}g_{i+1,j-k}\ h_{i,j}&=sum_{k=0}^{j-1}h_{i,k}g_{i+1,j-k-1}q^i(1-q) end{align} ]

      因为第(i)行的宽度不会超过(lfloorfrac{k}{i} floor),所以的暴力的时间复杂度是(sum_{i=1}^k{lfloorfrac{k}{i} floor}^2=O(k^2))

      这已经足够了,但我们可以做的更好。

      设

    [egin{align} A_i(x)&=sum_{jgeq 0}g_{i,j}x^j\ B_i(x)&=sum_{jgeq 0}h_{i,j}x^j\ c_i&=q^i(1-q)\ end{align} ]

    那么

    [egin{align} A_i(x)&=B_i(x)A_{i+1}(x)\ B_i(x)&=c_ixA_{i+1}(x)B_i(x)+1\ B_i(x)&=frac{1}{1-c_ixA_{i+1}(x)}\ end{align} ]

      时间复杂度是(sum_{i=1}^klfloorfrac{k}{i} floorloglfloorfrac{k}{i} floor=O(klog^2k))

      设(f_i)为前(i)列最大子矩形(leq k)的概率,那么

    [f_i=sum_{j=1}^kf_{i-j-1}g_{1,j}(1-q) ]

      这就是一个常系数线性递推。

    [egin{align} a_i&=g_{1,i-1}(1-q)\ f_i&=sum_{j=1}^kf_{i-j}a_j end{align} ]

      时间复杂度:

    • 暴力:(O(nk))(70)pts
    • 矩阵快速幂:(O(k^3log n))(90)pts
    • 特征多项式+暴力:(O(k^2log n))(100)pts
    • 特征多项式+NTT取模:(O(klog klog n))(100)pts

      这里简单讲一下最后一个做法

      矩阵快速幂是给你一个矩阵(A),求((A^n)_{1,1})

      设矩阵的大小为(k)

      根据Cayley-Hamilton定理,(|lambda I-A|)是一个关于(lambda)(k)次多项式,记为(g(lambda))。对于任意矩阵(A),有(g(A)=0)

      对于常系数线性递推的矩阵,设(f_i=sum_{j=1}^kf_{i-j}a_j)(g(lambda)=lambda^k-sum_{i=1}^{k}a_{i}lambda^{k-i})

      所以我们只需要求(A^nmod g(A))。可以用快速幂(倍增取模)求解。

      然后还要求出(f_1ldots f_k),可以通过其他方法计算(多项式求逆或者题目给你了)。

      最后一次卷积可以得到答案。

      如果要求(f_{n-k+1}ldots f_n),那就把(f_1ldots f_{2k})带进去卷积。

      总时间复杂度:(O(klog^2k+klog klog n))

    代码

      暴力取模

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cstdlib>
    #include<ctime>
    #include<utility>
    #include<cmath>
    #include<functional>
    using namespace std;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int,int> pii;
    typedef pair<ll,ll> pll;
    void sort(int &a,int &b)
    {
    	if(a>b)
    		swap(a,b);
    }
    void open(const char *s)
    {
    #ifndef ONLINE_JUDGE
    	char str[100];
    	sprintf(str,"%s.in",s);
    	freopen(str,"r",stdin);
    	sprintf(str,"%s.out",s);
    	freopen(str,"w",stdout);
    #endif
    }
    int rd()
    {
    	int s=0,c;
    	while((c=getchar())<'0'||c>'9');
    	do
    	{
    		s=s*10+c-'0';
    	}
    	while((c=getchar())>='0'&&c<='9');
    	return s;
    }
    int upmin(int &a,int b)
    {
    	if(b<a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    int upmax(int &a,int b)
    {
    	if(b>a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    ll p=998244353;
    void add(ll &a,ll b)
    {
    	a=(a+b)%p;
    }
    ll fp(ll a,ll b)
    {
    	ll s=1;
    	for(;b;b>>=1,a=a*a%p)
    		if(b&1)
    			s=s*a%p;
    	return s;
    }
    ll inv(ll a)
    {
    	return fp(a,p-2);
    }
    ll pw1[1010];
    ll pw2[1010];
    ll q;
    ll q2;
    ll g[1010][1010];
    ll h[1010][1010];
    ll f[2010];
    ll a[2010];
    ll c[2010];
    ll d[2010];
    ll final[2010];
    void mul(ll *a,ll *b,ll *e,int len)
    {
    	static ll c[2010];
    	int i,j;
    	for(i=0;i<=2*len;i++)
    		c[i]=0;
    	for(i=0;i<=len;i++)
    		for(j=0;j<=len;j++)
    			add(c[i+j],a[i]*b[j]);
    	for(i=2*len;i>=len;i--)
    	{
    		ll v=c[i]*inv(e[len]);
    		if(v)
    			for(j=0;j<=len;j++)
    				c[i-len+j]=(c[i-len+j]-e[j]*v)%p;
    	}
    	for(i=0;i<=len;i++)
    		a[i]=c[i];
    }
    ll solve(int n,int k)
    {
    	if(!k)
    		return fp(q2,n);
    	memset(g,0,sizeof g);
    	memset(h,0,sizeof h);
    	g[k][1]=q2*pw1[k]%p;
    	g[k][0]=1;
    	int i,j,l;
    	for(i=k-1;i>=1;i--)
    	{
    		int m=k/i;
    		g[i][0]=1;
    		h[i][0]=1;
    		for(j=0;j<=m;j++)
    		{
    			for(l=j+1;l<=m;l++)
    				add(h[i][l],h[i][j]*g[i+1][l-j-1]%p*q2%p*pw1[i]%p);
    			for(l=j;l<=m;l++)
    				if(l)
    					add(g[i][l],h[i][j]*g[i+1][l-j]%p);
    		}
    	}
    	memset(f,0,sizeof f);
    	f[0]=1;
    	for(i=1;i<=2*(k+1);i++)
    		for(j=0;j<i&&j<=k;j++)
    			add(f[i],f[i-j-1]*q2%p*g[1][j]);
    	if(n<=2*(k+1))
    	{
    		ll s=0;
    		for(i=0;i<=n&&i<=k;i++)
    			add(s,f[n-i]*g[1][i]);
    		return s;
    	}
    	int len=k+1;
    	for(i=0;i<len;i++)
    		a[i]=-q2*g[1][len-i-1]%p;
    	a[len]=1;
    	memset(c,0,sizeof c);
    	c[1]=1;
    	memset(d,0,sizeof d);
    	d[0]=1;
    	int m=n-k-1;
    	while(m)
    	{
    		if(m&1)
    			mul(d,c,a,len);
    		mul(c,c,a,len);
    		m>>=1;
    	}
    	memset(final,0,sizeof final);
    	for(i=1;i<=k+1;i++)
    		for(j=0;j<=k;j++)
    			add(final[i],d[j]*f[i+j]);
    	ll s=0;
    	for(i=1;i<=k+1;i++)
    		add(s,final[i]*g[1][k+1-i]);
    	return s;
    }
    int main()
    {
    	open("bzoj4944");
    	int n,k,x,y;
    	scanf("%d%d%d%d",&n,&k,&x,&y);
    	q=x*inv(y)%p;
    	q2=(y-x)*inv(y)%p;
    	pw1[0]=pw2[0]=1;
    	int i;
    	for(i=1;i<=k;i++)
    	{
    		pw1[i]=pw1[i-1]*q%p;
    		pw2[i]=pw2[i-1]*q2%p;
    	}
    	ll ans1=solve(n,k);
    	ll ans2=solve(n,k-1);
    	ll ans=((ans1-ans2)%p+p)%p;
    	printf("%lld
    ",ans);
    	return 0;
    }
    

      NTT取模

    #include<cstdio>
    #include<cstring>
    #include<algorithm>
    #include<cstdlib>
    #include<ctime>
    #include<utility>
    #include<cmath>
    #include<functional>
    using namespace std;
    typedef long long ll;
    typedef unsigned long long ull;
    typedef pair<int,int> pii;
    typedef pair<ll,ll> pll;
    void sort(int &a,int &b)
    {
    	if(a>b)
    		swap(a,b);
    }
    void open(const char *s)
    {
    #ifndef ONLINE_JUDGE
    	char str[100];
    	sprintf(str,"%s.in",s);
    	freopen(str,"r",stdin);
    	sprintf(str,"%s.out",s);
    	freopen(str,"w",stdout);
    #endif
    }
    int rd()
    {
    	int s=0,c;
    	while((c=getchar())<'0'||c>'9');
    	do
    	{
    		s=s*10+c-'0';
    	}
    	while((c=getchar())>='0'&&c<='9');
    	return s;
    }
    int upmin(int &a,int b)
    {
    	if(b<a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    int upmax(int &a,int b)
    {
    	if(b>a)
    	{
    		a=b;
    		return 1;
    	}
    	return 0;
    }
    const ll p=998244353;
    const int maxn=300000;
    ll fp(ll a,ll b)
    {
    	ll s=1;
    	for(;b;b>>=1,a=a*a%p)
    		if(b&1)
    			s=s*a%p;
    	return s;
    }
    namespace ntt
    {
    	const ll g=3;
        ll w1[maxn];
        ll w2[maxn];
        int rev[maxn];
        int n;
        void init(int m)
        {
            n=1;
            while(n<m)
                n<<=1;
            int i;
            for(i=2;i<=n;i<<=1)
            {
                w1[i]=fp(g,(p-1)/i);
                w2[i]=fp(w1[i],p-2);
            }
            rev[0]=0;
            for(i=1;i<n;i++)
                rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
        }
        void ntt(ll *a,int t)
        {
            int i,j,k;
            ll u,v,w,wn;
            for(i=0;i<n;i++)
                if(rev[i]<i)
                    swap(a[i],a[rev[i]]);
            for(i=2;i<=n;i<<=1)
            {
                wn=(t==1?w1[i]:w2[i]);
                for(j=0;j<n;j+=i)
                {
                    w=1;
                    for(k=j;k<j+i/2;k++)
                    {
                        u=a[k];
                        v=a[k+i/2]*w%p;
    					a[k]=(u+v)%p;
    					a[k+i/2]=(u-v)%p;
                        w=w*wn%p;
                    }
                }
            }
            if(t==-1)
            {
                u=fp(n,p-2);    
                for(i=0;i<n;i++)
                    a[i]=a[i]*u%p;
            }
        }
        ll x[maxn];
        ll y[maxn];
        ll z[maxn];
        void copy_clear(ll *a,ll *b,int m)
        {
            int i;
            for(i=0;i<m;i++)
                a[i]=b[i];
            for(i=m;i<n;i++)
                a[i]=0;
        }
        void copy(ll *a,ll *b,int m)
        {
            int i;
            for(i=0;i<m;i++)
                a[i]=b[i];
        }
        void mul(ll *a,ll *b,ll *c,int m)
        {
        	init(m<<1);
        	copy_clear(x,a,m);
        	copy_clear(y,b,m);
        	ntt(x,1);
        	ntt(y,1);
        	int i;
        	for(i=0;i<n;i++)
        		x[i]=x[i]*y[i]%p;
        	ntt(x,-1);
        	copy(c,x,m);
        }
        void inverse(ll *a,ll *b,int m)
        {
            if(m==1)
            {
                b[0]=fp(a[0],p-2);
                return;
            }
            inverse(a,b,m>>1);
            init(m<<1);
            copy_clear(x,a,m);
            copy_clear(y,b,m>>1);
            ntt(x,1);
            ntt(y,1);
            int i;
            for(i=0;i<n;i++)
                x[i]=y[i]*(2-x[i]*y[i]%p)%p;
        	ntt(x,-1);
        	copy(b,x,m);
        }
        ll c[maxn],d[maxn],e[maxn],f[maxn];
        void sqrt(ll *a,ll *b,int m)
        {
        	if(m==1)
        	{
        		if(a[0]==1)
        			b[0]=1;
        		else if(a[0]==0)
        			b[0]=0;
        		else
        			//我也不会
    				;
    			return;
    		}
    		sqrt(a,b,m>>1);
    //		copy_clear(c,b,m>>1);
    		int i;
    		for(i=m;i<m<<1;i++)
    			b[i]=0;
    		inverse(b,d,m);
    		init(m<<1);
    		for(i=m;i<m<<1;i++)
    			b[i]=d[i]=0;
    		ll inv2=fp(2,p-2);
    		copy_clear(x,a,m);
    		ntt(x,1);
    		ntt(d,1);
    		for(i=0;i<n;i++)
    			x[i]=x[i]*d[i]%p;
    		ntt(x,-1);
    		for(i=0;i<m;i++)
    			b[i]=((b[i]+x[i])%p*inv2)%p;
    	}
        void derivative(ll *a,ll *b,int m)
    	{
    		int i;
    		for(i=0;i<m-1;i++)
    			b[i]=(i+1)*a[i+1]%p;
    		b[m-1]=0;
    	}
        void differential(ll *a,ll *b,int m)
        {
    //    	int i;
    //    	for(i=m-1;i>=1;i--)
    //    		b[i]=a[i-1]*inv[i]%p;
        	b[0]=0;
        }
        void ln(ll *a,ll *b,int m)
        {
        	static ll c[maxn],d[maxn];
        	derivative(a,c,m);
        	inverse(a,d,m);
        	init(m<<1);
        	int i;
        	for(i=m;i<n;i++)
        		c[i]=d[i]=0;
        	ntt(c,1);
        	ntt(d,1);
        	for(i=0;i<n;i++)
        		c[i]=c[i]*d[i]%p;
        	ntt(c,-1);
        	differential(c,b,m);
        }
        void exp(ll *a,ll *b,int m)
        {
        	if(m==1)
        	{
        		b[0]=1;
        		return;
        	}
        	exp(a,b,m>>1);
        	int i;
        	for(i=m>>1;i<m;i++)
        		b[i]=0;
        	ln(b,y,m);
        	init(m<<1);
        	copy_clear(x,a,m);
        	x[0]++;
        	for(i=0;i<m;i++)
        		x[i]=(x[i]-y[i])%p;
        	copy_clear(y,b,m);
        	ntt(x,1);
        	ntt(y,1);
        	for(i=0;i<n;i++)
        		x[i]=x[i]*y[i]%p;
        	ntt(x,-1);
        	copy(b,x,m);
        }
        void module(ll *a,ll *b,ll *c,int n1,int n2)
        {
        	int k=1;
        	while(k<=n1-n2+1)
        		k<<=1;
        	int i;
        	for(i=0;i<=n1;i++)
        		d[i]=a[i];
        	for(i=0;i<=n2;i++)
        		e[i]=b[i];
        	reverse(d,d+n1+1);
        	reverse(e,e+n2+1);
        	for(i=n1-n2+1;i<k<<1;i++)
        		d[i]=e[i]=0;
        	inverse(e,f,k);
        	for(i=n1-n2+1;i<k<<1;i++)
        		f[i]=0;
        	init(k<<1);
        	ntt::ntt(d,1);
        	ntt::ntt(f,1);
        	for(i=0;i<n;i++)
        		e[i]=d[i]*f[i]%p;
        	ntt::ntt(e,-1);
        	for(i=0;i<=n1-n2;i++)
        		c[i]=e[i];
        	reverse(c,c+n1-n2+1);
        }
    };
    void add(ll &a,ll b)
    {
    	a=(a+b)%p;
    }
    ll inv(ll a)
    {
    	return fp(a,p-2);
    }
    ll pw1[maxn];
    ll pw2[maxn];
    ll q;
    ll q2;
    ll f[maxn];
    ll a[maxn];
    ll c[maxn];
    ll d[maxn];
    ll final[maxn];
    ll g[2][maxn];
    ll h[maxn];
    ll e[maxn];
    
    void mul(ll *a,ll *b,ll *c,int n)
    {
    	static ll d[maxn],e[maxn];
    	int k=1;
    	while(k<=n)
    		k<<=1;
    	ntt::init(k<<1);
    	int i;
    	for(i=0;i<k<<1;i++)
    		d[i]=e[i]=0;
    	for(i=0;i<=n;i++)
    	{
    		d[i]=a[i];
    		e[i]=b[i];
    	}
    	ntt::ntt(d,1);
    	ntt::ntt(e,1);
    	for(i=0;i<k<<1;i++)
    		d[i]=d[i]*e[i]%p;
    	ntt::ntt(d,-1);
    	//d=a*b
    	for(i=0;i<k<<1;i++)
    		e[i]=0;
    	int n2=(k<<1)-1;
    	while(!d[n2])
    		n2--;
    	ntt::module(d,c,e,n2,n);
    	for(i=0;i<n;i++)
    		a[i]=d[i];
    	for(i=0;i<k;i++)
    		d[i]=c[i];
    	for(i=k;i<k<<1;i++)
    		d[i]=0;
    	ntt::init(k<<1);
    	ntt::ntt(d,1);
    	ntt::ntt(e,1);
    	for(i=0;i<k<<1;i++)
    		d[i]=d[i]*e[i]%p;
    	ntt::ntt(d,-1);
    	for(i=0;i<n;i++)
    		a[i]=(a[i]-d[i])%p;
    }
    void powmod(ll *a,ll *b,ll *c,int m,int n)
    {
    	if(!n)
    		return;
    	powmod(a,b,c,m,n>>1);
    	mul(a,a,c,m);
    	if(n&1)
    		mul(a,b,c,m);
    }
    ll solve(int n,int k)
    {
    	memset(g,0,sizeof g);
    	memset(h,0,sizeof h);
    	int now=0;
    	g[now][1]=q2*pw1[k]%p;
    	g[now][0]=1;
    	h[0]=1;
    	int i,j;
    	for(i=k-1;i>=1;i--)
    	{
    		now^=1;
    		int m=k/i;
    		ll c=q2*pw1[i]%p;
    		int len=1;
    		while(len<=m)
    			len<<=1;
    		for(j=1;j<len;j++)
    			e[j]=-c*g[now^1][j-1];
    		e[0]=1;
    		ntt::inverse(e,h,len);
    		for(j=m+1;j<len<<1;j++)
    			h[j]=0;
    		ntt::init(len<<1);
    		ntt::ntt(g[now^1],1);
    		ntt::ntt(h,1);
    		for(j=0;j<len<<1;j++)
    			g[now][j]=g[now^1][j]*h[j]%p;
    		ntt::ntt(g[now],-1);
    		for(j=m+1;j<len<<1;j++)
    			g[now][j]=0;
    	}
    	memset(a,0,sizeof a);
    	for(i=0;i<=k;i++)
    		a[i+1]=-g[now][i]*q2%p;
    	a[0]=1;
    	int len=1;
    	while(len<=k+1)
    		len<<=1;
    	ntt::inverse(a,f,len<<1);
    	if(n<=2*(k+1))
    	{
    		ll s=0;
    		for(i=0;i<=n&&i<=k;i++)
    			add(s,f[n-i]*g[now][i]);
    		return s;
    	}
    	memset(a,0,sizeof a);
    	memset(c,0,sizeof c);
    	memset(d,0,sizeof d);
    	for(i=0;i<=k;i++)
    		a[i]=-g[now][k-i]*q2%p;
    	a[k+1]=1;
    	if(k)
    		c[1]=1;
    	else
    		c[0]=-a[0];
    	d[0]=1;
    	int m=n-k;
    	powmod(d,c,a,k+1,m);
    //	while(m)
    //	{
    //		if(m&1)
    //			mul(d,c,a,k+1);
    //		mul(c,c,a,k+1);
    //		m>>=1;
    ////		for(i=0;i<=k;i++)
    ////			printf("%lld ",(d[i]+p)%p);
    ////		printf("
    ");
    //	}
    	reverse(d,d+k+1);
    	ntt::init(len<<2);
    	ntt::ntt(d,1);
    	ntt::ntt(f,1);
    	for(i=0;i<len<<2;i++)
    		final[i]=d[i]*f[i]%p;
    	ntt::ntt(final,-1);
    	ll s=0;
    	for(i=0;i<=k;i++)
    		add(s,g[now][i]*final[2*k-i]);
    	return s;
    //	for(i=0;i<=k;i++)
    //		g[now][i]=(g[now][i]+p)%p;
    //	memset(f,0,sizeof f);
    //	f[0]=1;
    //	for(i=1;i<=2*(k+1);i++)
    //		for(j=0;j<i&&j<=k;j++)
    //			add(f[i],f[i-j-1]*q2%p*g[now][j]);
    //	if(n<=2*(k+1))
    //	{
    //		ll s=0;
    //		for(i=0;i<=n&&i<=k;i++)
    //			add(s,f[n-i]*g[now][i]);
    //		return s;
    //	}
    //	int len=k+1;
    //	for(i=0;i<len;i++)
    //		a[i]=-q2*g[now][len-i-1]%p;
    //	a[len]=1;
    //	memset(c,0,sizeof c);
    //	c[1]=1;
    //	memset(d,0,sizeof d);
    //	d[0]=1;
    //	int m=n-k-1;
    //	while(m)
    //	{
    //		if(m&1)
    //			mul(d,c,a,len);
    //		mul(c,c,a,len);
    //		m>>=1;
    //	}
    //	memset(final,0,sizeof final);
    //	for(i=1;i<=k+1;i++)
    //		for(j=0;j<=k;j++)
    //			add(final[i],d[j]*f[i+j]);
    //	ll s=0;
    //	for(i=1;i<=k+1;i++)
    //		add(s,final[i]*g[now][k+1-i]);
    //	return s;
    }
    int main()
    {
    	open("bzoj4944");
    	int n,k,x,y;
    	scanf("%d%d%d%d",&n,&k,&x,&y);
    	q=x*inv(y)%p;
    	q2=(y-x)*inv(y)%p;
    	pw1[0]=pw2[0]=1;
    	int i;
    	for(i=1;i<=k;i++)
    	{
    		pw1[i]=pw1[i-1]*q%p;
    		pw2[i]=pw2[i-1]*q2%p;
    	}
    	ll ans1=solve(n,k);
    	ll ans2=solve(n,k-1);
    	ll ans=((ans1-ans2)%p+p)%p;
    	printf("%lld
    ",ans);
    	return 0;
    }
    
  • 相关阅读:
    运算符
    练习
    JAVA学习日报 9/23
    JAVA学习日报 8.22
    JAVA学习日报 8.21
    第一节:SpringMVC 处理请求数据【1】
    第六节:@RequestMapping 映射请求占位符 @PathVariable 注解
    第一节:REST 风格的URL地址约束
    第二节:REST风格的案例及源码分析
    (一)IOC 容器:【1】@Configuration&@Bean 给容器中注册组件
  • 原文地址:https://www.cnblogs.com/ywwyww/p/8513404.html
Copyright © 2011-2022 走看看