zoukankan      html  css  js  c++  java
  • Codeforces 719E (线段树教做人系列) 线段树维护矩阵

    题面简洁明了,一看就懂

    做了这个题之后,才知道怎么用线段树维护递推式。递推式的递推过程可以看作两个矩阵相乘,假设矩阵A是初始值矩阵,矩阵B是变换矩阵,求第n项相当于把矩阵B乘了n - 1次。

    那么我们线段树中每个点维护把矩阵B乘了多少次,懒标记下放的时候用快速幂维护sum。

    #include <bits/stdc++.h>
    #define LL long long
    #define ls(x) (x << 1)
    #define rs(x) ((x << 1) | 1)
    using namespace std;
    const LL mod = 1000000007;
    const int maxn = 100010;
    struct Matrix {
    	static const int len = 2;
    	LL x[len][len];
    	
    	void init() {
    		memset(x, 0, sizeof(x));
    		for (int i = 0; i < len; i++)
    			x[i][i] = 1;
    	}
    	
    	void zero() {
    		memset(x, 0, sizeof(x));
    	}
    	
    	Matrix operator * (const Matrix& m) const {
    		Matrix ans;
    		ans.zero();
    		for (int i = 0; i < len; i++)
    			for (int j = 0; j < len; j++)
    				for (int k = 0; k < len; k++)
    					ans.x[i][j] = (ans.x[i][j] + x[i][k] * m.x[k][j]) % mod;
    		return ans;
    	}
    	
    	Matrix operator + (const Matrix& m) const {
    		Matrix ans;
    		ans.zero();
    		for (int i = 0; i < len; i++)
    			for (int j = 0; j < len; j++)
    				ans.x[i][j] = (x[i][j] + m.x[i][j]) % mod;
    		return ans;
    	}
    	
    	Matrix operator ^ (int b) const {
    		Matrix ans, a;
    		ans.init();
    		memcpy(a.x, x, sizeof(x));
    		for (; b; b >>= 1) {
    			if(b & 1) ans = ans * a;
    			a = a * a;
    		}
    		return ans;
    	}
    };
    
    Matrix mul , tmp, trans ;
    int a[maxn];
    struct SegementTree {
    	int lz;
    	Matrix sum, flag;
    };
    
    SegementTree tr[maxn * 4];
    
    void maintain(int o) {
    	tr[o].sum = tr[ls(o)].sum + tr[rs(o)].sum;
    }
    
    void pushdown(int o) {
    	if(tr[o].lz) {
    		tr[ls(o)].sum = tr[ls(o)].sum * tr[o].flag;
    		tr[rs(o)].sum = tr[rs(o)].sum * tr[o].flag;
    		tr[ls(o)].flag = tr[ls(o)].flag * tr[o].flag;
    		tr[rs(o)].flag = tr[rs(o)].flag * tr[o].flag;
    		tr[o].lz = 0;
    		tr[ls(o)].lz = 1;
    		tr[rs(o)].lz = 1;
    		tr[o].flag.init();
    	}
    }
    
    void build(int o, int l, int r) {
    	tr[o].sum.zero();
    	tr[o].lz = 0;
    	tr[o].flag.init();
    	if(l == r) {
    		tr[o].sum = trans * ( mul ^ (a[l] - 1));
    		return;
    	}
    	int mid = (l + r) >> 1;
    	build(ls(o), l, mid);
    	build(rs(o), mid + 1, r);
    	maintain(o);
    }
    
    void update(int o, int l, int r, int ql, int qr, Matrix now) {
    	if(l >= ql && r <= qr) {
    		tr[o].sum = tr[o].sum * now;
    		tr[o].flag = tr[o].flag * now;
    		tr[o].lz = 1;
    		return;
    	}
    	pushdown(o);
    	int mid = (l + r) >> 1;
    	if(ql <= mid) update(ls(o), l, mid, ql, qr, now);
    	if(qr > mid) update(rs(o), mid + 1, r, ql, qr, now);
    	maintain(o);
    }
    
    LL query(int o, int l, int r, int ql, int qr) {
    	if(l >= ql && r <= qr) {
    		return tr[o].sum.x[0][1];
    	}
    	pushdown(o);
    	int mid = (l + r) >> 1;
    	LL ans = 0;
    	if(ql <= mid) ans = (ans + query(ls(o), l, mid, ql, qr)) % mod;
    	if(qr > mid) ans = (ans + query(rs(o), mid + 1, r, ql, qr)) % mod;
    	return ans;
    }
    
    int main() {
    	int n, m, op, l, r;
    	LL x;
    	trans.zero();
    	trans.x[0][1] = 1;
    	mul.x[0][1] = mul.x[1][0] = mul.x[1][1] = 1;
    	mul.x[0][0] = 0;
    	scanf("%d%d", &n, &m);
    	for (int i = 1; i <= n; i++) {
    		scanf("%d", &a[i]);
    	}
    	build(1, 1, n);
    	for (int i = 1; i <= m; i++) {
    		scanf("%d%d%d", &op, &l, &r);
    		if(op == 1) {
    			scanf("%lld", &x);
    			tmp = (mul ^ x);
    			update(1, 1, n, l, r, tmp);
    		} else {
    			printf("%lld
    ", query(1, 1, n, l, r));
    		}
    	}
    }
    

      

  • 相关阅读:
    对数值计算numpy的一些总结,感兴趣的朋友可以看看
    mysql基础语法(部分)
    python_内建结构
    07_go语言基础
    06_go语言基础
    05_go语言基础常量
    04_go语言基础
    03_go语言基础
    02_go语言基础
    01_go语言基础
  • 原文地址:https://www.cnblogs.com/pkgunboat/p/10608454.html
Copyright © 2011-2022 走看看