zoukankan      html  css  js  c++  java
  • 2020牛客寒假算法基础集训营2 J-求函数 (线段树维护矩阵乘法)

    题目链接:https://ac.nowcoder.com/acm/contest/3003/J

    思路:

    方法①

    f1(1)=k1+b1=(k1)+(b1)

    f2(f1(1))=k2(f1(1))+b2=k2k1+k2b1+b2=(k2k1)+(k2b1+b2)

    f3(f2(f1(1)))=(k3k2k1)+(k3k2b1+k3b2+b3)

    通过上面的展开,我们可以发现一个式子可以分成两部分:∏Ki  与  ∑ri=l(bi*∏rj=i+1Kj)

    分别用线段树维护这两部分即可,现在考虑如果合并区[l, r] 与 [r1+1 ,r]

    假设左区间的第一部分为 N第二部分为 M1

      右区间的第一部分为 N2 第二部分为 M2

    合并后区间的第一部分为N1*N2,第二部分为N2 * M1 + M2

    #include<iostream>
    #include<algorithm>
    #include<cstring>
     using namespace std;
     typedef long long ll;
     const int mod=1e9+7;
     const int maxn=2e5+10;
     struct node{
         ll l,r,k,b;
     }tree[maxn<<2];
     ll k[maxn],b[maxn],n,m,op,l1,r1,po,k1,b1;
     void pushup(int rt)
     {
         tree[rt].k=(tree[rt<<1].k*tree[rt<<1|1].k)%mod;
         tree[rt].b=((tree[rt<<1|1].k*tree[rt<<1].b)%mod+tree[rt<<1|1].b)%mod;
     }
     void build(ll rt,ll l,ll r)
     {
         tree[rt].l=l;
         tree[rt].r=r;
         if(l==r){
             tree[rt].k=k[l],tree[rt].b=b[l];
             return;
         }
        ll mid=(l+r)>>1;
        build(rt<<1,l,mid);
        build(rt<<1|1,mid+1,r);
        pushup(rt);
     }
     void update(ll rt,ll pos)
     {
         if(tree[rt].l==pos&&tree[rt].r==pos){
             tree[rt].k=k[pos],tree[rt].b=b[pos];
             return;
         }
        ll mid=(tree[rt].l+tree[rt].r)>>1;
        if(pos<=mid)    update(rt<<1,pos);
        else update(rt<<1|1,pos);
        pushup(rt);
      }
      typedef pair<ll,ll> p;
      p query(int rt,int l,int r,int ll,int rr)
    {
        if(ll>r || rr<l) return p(-1,-1);
        if(l>=ll && r<=rr) return p(tree[rt].k,tree[rt].b);
        int mid=(l+r)>>1;
        p p1=query(rt<<1,l,mid,ll,rr);
        p p2=query(rt<<1|1,mid+1,r,ll,rr);
        if(p1.first==-1) return p2;
        if(p2.first==-1) return p1;
        int k1=p1.first,b1=p1.second;
        int k2=p2.first,b2=p2.second;
        return p(1ll*k1*k2%mod,(1ll*b1*k2+b2)%mod);
    }
     int main()
     {
         scanf("%lld%lld",&n,&m);
         for(int i=1;i<=n;i++) scanf("%d",&k[i]);
         for(int i=1;i<=n;i++) scanf("%d",&b[i]);
         build(1,1,n);
         for(int i=1;i<=m;i++){
             scanf("%lld",&op);
             if(op==1){
                 scanf("%lld%lld%lld",&po,&k1,&b1);
                k[po]=k1;
                b[po]=b1;
                update(1,po);
             } 
             else{
                 scanf("%lld%lld",&l1,&r1);
                 p p1=query(1,1,n,l1,r1);
                 int k=p1.first,b=p1.second;
                printf("%d
    ",((k+b)%mod+mod)%mod);
             } 
         }
        return 0;
     }

     方法②矩阵乘法

    这篇博客讲的不错:https://www.cnblogs.com/BakaCirno/p/12270838.html

    #include<iostream>
    #include<iostream>
    #include<cstring>
    #define mid ((l+r)>>1)
     typedef long long ll;
     using namespace std;
     const int maxn=2e5+10;
     const int mod=1e9+7;
     int n,m;
     ll k[maxn],b[maxn];
     struct MX{
         ll m[2][2];
         MX(){memset(m,0,sizeof(m));}
        friend MX operator *(const MX&a,const MX&b){
            MX res;
            for(int i=0;i<2;i++)
                for(int j=0;j<2;j++){
                    for(int k=0;k<2;k++)
                        res.m[i][j]+=a.m[i][k]*b.m[k][j];
                    res.m[i][j]%=mod;
                }
            return res;
        }
     }mx[maxn<<2];
     void update(int rt,int l,int r,int pos)
     {
         if(l==r){mx[rt].m[0][0]=k[l],mx[rt].m[1][0]=b[l],mx[rt].m[1][1]=1;return;}
         if(pos<=mid)    update(rt<<1,l,mid,pos);
         else update(rt<<1|1,mid+1,r,pos);
         mx[rt]=mx[rt<<1]*mx[rt<<1|1];
     }
     MX query(int rt,int l,int r,int L,int R)
     {
         if(L<=l&&r<=R)    return mx[rt];
         MX res;
         res.m[0][0]=res.m[1][1]=1;
         if(L<=mid)    res=res*query(rt<<1,l,mid,L,R);
         if(R>mid)    res=res*query(rt<<1|1,mid+1,r,L,R);
         return res;
     }
     int main()
     {
         cin>>n>>m;
         for(int i=1;i<=n;i++)    scanf("%lld",&k[i]);
         for(int i=1;i<=n;i++)    scanf("%lld",&b[i]);
         for(int i=1;i<=n;i++)    update(1,1,n,i);
         for(int i=1,opt,l,r;i<=m;i++){
             cin>>opt;
             if(opt==1){
                 scanf("%d",&l);
                 scanf("%lld%lld",&k[l],&b[l]);
                 update(1,1,n,l);
             }
             else{
                 scanf("%d%d",&l,&r);
                 MX res;
                 res.m[0][0]=res.m[0][1]=1;
                 res=res*query(1,1,n,l,r);
                 printf("%lld
    ",res.m[0][0]%mod);
             }
         }
      } 
  • 相关阅读:
    剑指offer 整数中1出现的次数(从1到n整数中1出现的次数)
    剑指offer 把数组排成最小的数
    剑指offer 丑数
    剑指offer 字符串的排列
    剑指offer 数组中出现次数超过一半的数字
    剑指offer 最小的K个数
    操作系统 页面置换算法(C++实现)
    剑指offer 二叉搜索树与双向链表
    剑指offer 复杂链表的复制
    操作系统 银行家算法(C++实现)
  • 原文地址:https://www.cnblogs.com/overrate-wsj/p/12274408.html
Copyright © 2011-2022 走看看