zoukankan      html  css  js  c++  java
  • 「清华集训 2017」小 Y 和*的奴隶主

    弱化版

    为什么这里的题解都是写的顺推呢?这里提供一篇倒推的题解,为那些和我一样打倒退写wa的小伙伴提供一点能够借鉴的代码。

    对于这道弱化版的题目,我们考虑倒推,定义(f_{i,j,k,w})我们在被打了(i-1)次后,还剩下(j)个三血奴隶主,(k)个两血奴隶主,(w)个一血奴隶主在接下来的(i)~(k)次被打中受到伤害的期望。

    那么转移就会比较容易:

    定义(tot = j+k+w+1)

    (j+k+w < 7)时有:

    [f_{i,j,k,w} = (f_{i+1,j,k,w}+1) imes frac{1}{tot}+ f_{i+1,j,k+1,w} imes frac {j}{tot} + f_{i+1,j+1,k-1,w+1} imes frac{k}{tot} +f_{i+1,j,k,w-1} imes frac{w}{tot} ]

    否则有:

    [f_{i,j,k,w} = f_{i+1,j-1,k+1,w} imes frac {j}{tot} + f_{i+1,j,k-1,w+1} imes frac{k}{tot} +f_{i+1,j,k,w-1} imes frac{w}{tot} ]

    代码比较简单,注意一下输入顺序就好了。

    现在我们来看这一道题,我们仍然考虑倒退,发现(n leq 10^{18}),我们需要优化,我们发现对于每个(dp)状态的转移的参数与(i)并无关系,因此我们考虑矩阵加速。

    我们枚举出(j,k,w)的所有合法情况(由于我们转移的特殊性,我们需要再把1作为常数项加入矩阵参与转移),将其作为矩阵的大小,由于(K leq 8),因此状态的数量并不会太多,只有(166)个,我们枚举完合法情况将其编号过后枚举可以转移它的状态并计算出参数填入矩阵中,就可以进行矩阵优化了~。

    但是这个复杂度我们仍然无法通过本题,我们考虑继续优化。我们发现这个矩阵是固定的,因此我们可以预处理出矩阵的(2^i)的结果,我们在处理询问的时候只需要用一个单行去乘以我们预处理出来的一些矩阵就可以了。

    但是它还是死了,根据一些奇技淫巧,我们可以用__in128储存矩阵乘法的中间结果,最后再去取模,就能有效地减少取模次数,优化效果显著

    下面就是十分冗长的代码:

    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    using namespace std;
    typedef long long LL;
    const LL mod = 998244353;
     
    struct matrix {
    	LL a[170][170];
    	int n , m;
    	matrix operator * (const matrix &b) const{
    		matrix c;
    		__int128 C[170][170];
    		for (int i = 1; i <= n; ++i) {
    			for (int j = 1; j <= b.m; ++j) {
    				C[i][j] = 0;
    			}
    		}
    		for (int i = 1; i <= n; ++i) {
    			for (int k = 1; k <= m; ++k) {
    				for (int j = 1; j <= b.m; ++j) {
    					C[i][j] = C[i][j] + a[i][k] * b.a[k][j];
    				}
    			}
    		}
    		for (int i = 1; i <= n; ++i) {
    			for (int j = 1; j <= b.m; ++j) {
    				c.a[i][j] = C[i][j] % mod;
    			}
    		}
    		c.n = n;
    		c.m = b.m;
    		return c;
    	}
    }I; 
    int T , m , K , tot , id[10][10][10];
    matrix A , B[64] , beg , ans;
    LL get_inv(LL a) {
    	int b = mod - 2;
    	LL res = 1;
    	while(b) {
    		if(b & 1) res = res * a % mod;
    		a = a * a % mod;
    		b >>= 1;
    	}
    	return res;
    }
    void get_matrix_3 () {
    	for (int j = 0; j <= K; ++j) {
    		for (int k = 0; k <= K; ++k) {
    			for (int w = 0; w <= K; ++w) {
    				if(j + k + w > K) continue;
    				id[j][k][w] = ++tot;
    			}
    		}
    	}
    	++tot;
    	A.n = A.m = tot;
    	A.a[tot][tot] = 1;
    	for (int j = 0; j <= K; ++j) {
    		for (int k = 0; k <= K; ++k) {
    			for (int w = 0; w <= K; ++w) {
    				if(j + k + w > K) continue;
    				int now = id[j][k][w];
    				LL tmp = get_inv((LL)(j + k + w + 1));
    				A.a[now][now] = tmp;A.a[now][tot] = tmp;
    				if(j) {
    					if(j + k + w < K) A.a[now][id[j][k+1][w]] = (LL)j * tmp % mod;
    					else A.a[now][id[j-1][k+1][w]] = (LL)j * tmp % mod;
    				}
    				if(k) {
    					if(j + k + w < K) A.a[now][id[j+1][k-1][w+1]] = (LL)k * tmp % mod;
    					else A.a[now][id[j][k-1][w+1]] = (LL)k * tmp % mod;
    				}
    				if(w) A.a[now][id[j][k][w-1]] = (LL)w * tmp % mod; 
    			} 
    		}
    	}
    }
    int id2[10][10];
    void get_matrix_2 () {
    	for (int j = 0; j <= K; ++j) {
    		for (int k = 0; k <= K; ++k) {
    			if(j + k > K) continue;
    			id2[j][k] = ++tot;
    		}
    	}
    	++tot;
    	I.n = I.m = A.n = A.m = tot;
    	for (int j = 0; j <= K; ++j) {
    		for (int k = 0; k <= K; ++k) {
    			if(j + k > K) continue;
    			int now = id2[j][k];
    			LL tmp = get_inv((LL)(j + k + 1));
    			A.a[now][now] = tmp;A.a[now][tot] = tmp;
    			if(j) {
    				if(j + k < K) A.a[now][id2[j][k+1]] = (LL)j * tmp % mod;
    				else A.a[now][id2[j-1][k+1]] = (LL)j * tmp % mod;
    			}
    			if(k) A.a[now][id2[j][k-1]] = (LL)k * tmp % mod;
    		}
    	}
    	A.a[tot][tot] = 1;
    }
    void make_pow() {
    	B[0] = A;
    	for (int i = 1; i <= 59; ++i) B[i] = B[i - 1] * B[i - 1];
    }
    LL f[15][10];
    void work_1(int n) {
    	for (int i = n; i >= 1; --i) {
    		for (int j = 0; j <= K; ++j) {
    			LL tmp = get_inv((LL)(j + 1));
    			f[i][j] = (f[i+1][j] + 1) * tmp % mod;
    			if(j) f[i][j] += f[i+1][j-1] * (LL)j % mod * tmp % mod;
    		}
    	}
    	printf("%lld
    " , f[1][1]);
    	memset(f , 0 , sizeof f);
    }
    int main() {
    //	freopen("10.in" , "r" , stdin);
    	scanf("%d %d %d" , &T , &m , &K);
    	if(m == 3) {
    		get_matrix_3();
    		make_pow();
    		ans.n = tot;ans.m = 1;
    		while(T -- > 0) {
    			LL x;
    			scanf("%lld" , &x);
    			for (int i = 1; i < tot; ++i) ans.a[i][1] = 0;
    			ans.a[tot][1] = 1;
    			for (LL i = 60; i >= 1; --i) {
    				if(x >= (1ll << (i - 1ll))) {
    					x -= (1ll << (i - 1ll));
    					ans = B[i - 1] * ans;
    				}
    			}
    			printf("%lld
    " , ans.a[id[1][0][0]][1]);
    		}
    	}
    	else if(m == 2) {
    		get_matrix_2();
    		make_pow();
    		beg.n = tot;beg.m = 1;
    		beg.a[tot][1] = 1;
    		while(T -- > 0) {
    			LL x;
    			scanf("%lld" , &x);
    			ans = beg;
    			for (LL i = 60; i >= 1; --i) {
    				if(x >= (1ll << (i - 1ll))) {
    					x -= (1ll << (i - 1ll));
    					ans = B[i - 1] * ans;
    				}
    			}
    			printf("%lld
    " , ans.a[id2[1][0]][1]);
    		}
    	}
    	else {
    		while(T -- > 0) {
    			int n;
    			scanf("%d" , &n);
    			work_1(n);
    		}
    	}
    	return 0;
    }
    
  • 相关阅读:
    java 实现N进制转M进制
    BigInteger构造函数解析
    SpringBoot 实现前后端分离的跨域访问(CORS)
    python:[numpy] ndarray 与 list 互相转换
    PyTorch使用GPU的方法
    Matplotlib.pyplot 把画图保存为图片 指定图片大小
    python列表中的所有值转换为字符串,以及列表拼接成一个字符串
    python 读取中文文件名/中文路径
    在Python中使用LSTM和PyTorch进行时间序列预测(深度学习时序数据预测)
    记录分析python程序运行时间的几种方法
  • 原文地址:https://www.cnblogs.com/Reanap/p/13484732.html
Copyright © 2011-2022 走看看