zoukankan      html  css  js  c++  java
  • 2018焦作现场赛 H. Can You Solve the Harder Problem?(后缀数组+rmq/线段树+单调栈)

    题意

    在一个数组中,求所有本质不同子段的贡献和。

    每个子段的贡献为该子段中的最大值。

    ($n \leq 2 \times 10^5$,$T \leq 1000$)

    传送门

    思路

    首先,子段的贡献是子段中的最大值。先不考虑“本质不同”,记 $suf[i]$ 为所有以 $i$ 为左端点的子段的最大值之和,可以借助下面的 $nxt$ 数组递推:

    $nxt[i]$ 表示 $\min\{\, j \mid i < j \le n+1,\ a[j] > a[i] \,\}$,即 $i$ 右侧第一个严格大于 $a[i]$ 的位置(约定 $a[n+1] = +\infty$)。

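    则 $suf[i] = a[i] \cdot (nxt[i] - i) + suf[nxt[i]]$(约定 $suf[n+1] = 0$):右端点在 $[i,\ nxt[i]-1]$ 内的子段最大值都是 $a[i]$,右端点不小于 $nxt[i]$ 的子段最大值等于从 $nxt[i]$ 开始的那一段的最大值。$nxt[i]$ 可用单调栈求出。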

    之后用后缀数组处理“本质不同”。对后缀排序后,可以求出排名为 $i$ 的后缀与排名为 $i-1$ 的后缀的 LCP 值 $height[i]$。以 $sa[i]$ 为左端点、长度不超过 $height[i]$ 的子段都已在排名更靠前的后缀中出现过,因此该后缀只需贡献右端点在 $[sa[i]+height[i],\ n]$ 内的那些子段。

    对于 $height[i] = 0$ 的后缀,以 $sa[i]$ 为左端点的所有子段都是新出现的,贡献为 $suf[sa[i]]$;

    对于 $height[i] \neq 0$ 的后缀,其对答案的贡献可以分成两段考虑:

    设 $p$ 为 $[sa[i],\ sa[i]+height[i]-1]$ 中最大值的下标(可用 RMQ 或线段树查询),以 $nxt[p]$ 为分界点,将右端点的取值范围 $[sa[i]+height[i],\ n]$ 分成两部分:

    (1) 右端点在 $[sa[i]+height[i],\ nxt[p]-1]$ 内:子段最大值恒为 $a[p]$,贡献为 $a[p] \cdot (nxt[p]-sa[i]-height[i])$;

    (2) 右端点在 $[nxt[p],\ n]$ 内:由于 $a[nxt[p]]$ 大于 $[sa[i],\ nxt[p]-1]$ 中的所有元素,子段最大值只取决于从 $nxt[p]$ 开始的部分,贡献为 $suf[nxt[p]]$。

    最后,后缀数组的基数排序所用桶数与值域相关,因此需先把数组中的值离散化到 $[1, n]$ 内(计算贡献时仍使用原值)。

    Code

    #include <bits/stdc++.h>
    
    using namespace std;
    
    typedef long long ll;
    const int inf = 0x3f3f3f3f;   // sentinel stored at s[n+1]; larger than any discretized value
    // The statement bounds n <= 2e5 per test and arrays are indexed up to n+1, so a
    // small margin suffices. The previous 1e6+10 was 5x the bound and inflated every
    // static array (most of all the sparse table).
    const int maxn = 2e5+10;

    int T, n, s[maxn], has[maxn], pn;   // #tests, length, array (discretized in place), sorted distinct values, #distinct
    struct SuffixArray {
        int x[maxn], y[maxn], c[maxn];
        int sa[maxn], rk[maxn], height[maxn];
    
        void SA() {
            int m = pn;
            for (int i = 0; i <= m; ++i) c[i] = 0;
            for (int i = 1; i <= n; ++i) ++c[(x[i]=s[i])];
            for (int i = 1; i <= m; ++i) c[i] += c[i-1];
            for (int i = n; i >= 1; --i) sa[c[x[i]]--] = i;
    
            for (int p, k = 1; k <= n; k <<= 1) {
                p = 0;
                for (int i = n-k+1; i <= n; ++i) y[++p] = i ;
                for (int i = 1; i <= n; ++i) {
                    if(sa[i] > k) y[++p] = sa[i] - k;
                }
    
                for (int i = 0; i <= m; ++i) c[i] = 0;
                for (int i = 1; i <= n; ++i) ++c[x[y[i]]];
                for (int i = 1; i <= m; ++i) c[i] += c[i-1];
                for (int i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];
    
                p = y[sa[1]] = 1;
                for (int i = 2; i <= n; ++i) {
                    int a = sa[i]+k > n? -1 : x[sa[i]+k];
                    int b = sa[i-1]+k > n ? -1: x[sa[i-1]+k];
                    y[sa[i]] = (x[sa[i]] == x[sa[i-1]] && a == b) ? p : ++p;
                }
                swap(x, y);
                if(p >= n) break;
                m = p;
            }
            for (int i = 1; i <= n; ++i) rk[sa[i]] = i;
        }
    
        void getHeight() {
            for (int k = 0, i = 1; i <= n; ++i) {
                if(k) --k;
                int j = sa[rk[i]-1];
                while(s[i+k] == s[j+k]) ++k;
                height[rk[i]] = k;
            }
        }
    
        void build() {
            SA();
            getHeight();
        }
    
        void write() {
            for (int i = 1; i <= n; ++i) printf("%d ", sa[i]); puts("");
            for (int i = 1; i <= n; ++i) printf("%d ", rk[i]); puts("");
            for (int i = 1; i <= n; ++i) printf("%d ", height[i]); puts("");
        }
    }sa;
    
    int sta[maxn], top;     // monotonic stack used to compute nxt[]
    int nxt[maxn];          // nxt[i]: first j > i with s[j] > s[i] (n+1 if none)
    // st[i][j]: index of the maximum of s[i .. i+2^j-1], rightmost on ties.
    // Only levels 0..20 are ever built or queried, so 21 columns suffice.
    int st[maxn][21];
    ll suf[maxn];           // suf[i]: sum over r in [i, n] of max(a[i..r]), in original values
    
    // Index of the maximum of s[l..r] (requires l <= r). It finds the level
    // floor(log2(length)), then compares two overlapping sparse-table windows.
    // On equal values the right window's index is returned.
    int query(int l, int r) {
        const int len = r - l + 1;
        int lg = 0;
        while ((2 << lg) <= len) ++lg;   // lg = floor(log2(len))
        const int leftIdx = st[l][lg];
        const int rightIdx = st[r - (1 << lg) + 1][lg];
        return s[leftIdx] > s[rightIdx] ? leftIdx : rightIdx;
    }
    
    // Reads T test cases; for each, prints the sum of maxima over all distinct subarrays.
    int main() {
    //    freopen("input.in", "r", stdin);
        scanf("%d", &T);
        while(T--) {
            scanf("%d", &n);
            for (int i = 1; i <= n; ++i) {
                scanf("%d", s+i);
                has[i] = s[i];
            }

            // Discretize to 1..pn so the suffix array's counting sort has few buckets;
            // has[s[i]] recovers the original value.
            sort(has+1, has+1+n);
            pn = unique(has+1, has+1+n) - has - 1;
            for (int i = 1; i <= n; ++i) {
                s[i] = lower_bound(has+1, has+1+pn, s[i]) - has;
                st[i][0] = i;
            }

            // nxt[i] = first j > i with s[j] > s[i]; the inf sentinel at n+1 is never popped.
            s[n+1] = inf, sta[(top = 1)] = n+1;
            for (int i = n; i >= 1; --i) {
                while(top && s[sta[top]] <= s[i]) --top;
                nxt[i] = sta[top];
                sta[++top] = i;
            }
            nxt[n+1] = st[n + 1][0] = n + 1;

            // Sparse table: st[i][j] = argmax of s[i .. i+2^j-1]; the right half wins ties.
            for (int j = 1; j <= 20; ++j) {
                int p = 1 << (j - 1), l = (1 << j) - 1;
                for (int i = 1; i + l <= n; ++i) {
                    if(s[st[i][j - 1]] > s[st[i + p][j - 1]]) st[i][j] = st[i][j - 1];
                    else st[i][j] = st[i + p][j - 1];
                }
            }

            // suf[i] = sum over r in [i, n] of max(a[i..r]): a[i] is the max while r < nxt[i];
            // for r >= nxt[i] the max equals that of a[nxt[i]..r].
            suf[n+1] = 0;
            for (int i = n; i >= 1; --i) suf[i] = 1ll * has[s[i]] * (nxt[i] - i);
            for (int i = n; i >= 1; --i) suf[i] += suf[nxt[i]];

            sa.build();

            // Suffix sa[i] adds only its prefixes longer than height[i]; prefixes of length
            // <= height[i] already occur in the suffix ranked just before it.
            ll ans = 0;
            for (int i = 1; i <= n; ++i) {
                int h = sa.height[i];
                if(h == 0) {
                    ans += suf[sa.sa[i]];
                } else {
                    // pos = argmax of the already-counted prefix a[l..r]. Right ends in
                    // [r+1, nxt[pos]-1] have max a[pos]; the rest sum to suf[nxt[pos]].
                    int r = sa.sa[i] + h - 1;
                    int l = sa.sa[i];
                    int pos = query(l, r);
                    ans += suf[nxt[pos]] + 1ll * has[s[pos]] * (nxt[pos] - r - 1);
                }
            }
            printf("%lld\n", ans);
        }
        return 0;
    }
    

    水先生博客中的神仙读入挂优化后的代码:

    #include <bits/stdc++.h>
     
    using namespace std;
     
    typedef long long ll;              // 64-bit type for the answer and suf[]
    const int inf = 0x3f3f3f3f;        // sentinel placed at s[n+1]
    const int maxn = 200000 + 10;      // n <= 2e5, plus room for index n+1
     
    // Fast I/O helpers (taken from another author's blog), grouped by suffix:
    //   read / print / println / flush  - stdin through an fread buffer, stdout through
    //                                     the buffered Ostream (flushed on destruction)
    //   read1                           - getchar-based readers
    //   read2 / readln2                 - scanf-based readers
    //   print1 / println1 / flush1      - stdout through the Out buffer (flushed at exit)
    //   print2 / println2               - printf-based writers
    // Integer readers skip blanks and accept a leading '-'.
    namespace fastIO{
    #define BUF_SIZE 100000
    #define OUT_SIZE 100000
    #define ll long long
        //fread->read
        bool IOerror=0;   // set when nc() reaches the end of input
        // Next byte of stdin, served from a BUF_SIZE-byte fread buffer; at EOF sets IOerror and returns -1.
        inline char nc(){
            static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
            if (p1==pend){
                p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin);
                if (pend==p1){IOerror=1;return -1;}
            }
            return *p1++;
        }
        // Token separators. (These escapes had been mangled into raw line breaks, which does not compile.)
        inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}
        inline void read(int &x){
            bool sign=0; char ch=nc(); x=0;
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            if (ch=='-')sign=1,ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
            if (sign)x=-x;
        }
        inline void read(ll &x){
            bool sign=0; char ch=nc(); x=0;
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            if (ch=='-')sign=1,ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
            if (sign)x=-x;
        }
        inline void read(double &x){
            bool sign=0; char ch=nc(); x=0;
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            if (ch=='-')sign=1,ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
            if (ch=='.'){
                double tmp=1; ch=nc();
                for (;ch>='0'&&ch<='9';ch=nc())tmp/=10.0,x+=tmp*(ch-'0');
            }
            if (sign)x=-x;
        }
        // Reads one blank-delimited token into s (no bounds check) and NUL-terminates it.
        inline void read(char *s){
            char ch=nc();
            for (;blank(ch);ch=nc());
            if (IOerror)return;
            for (;!blank(ch)&&!IOerror;ch=nc())*s++=ch;
            *s=0;
        }
        // Reads the next non-blank character; c is -1 at EOF.
        inline void read(char &c){
            for (c=nc();blank(c);c=nc());
            if (IOerror){c=-1;return;}
        }
        //getchar->read
        // ch is an int so EOF can be detected: with a char, the skip loops below
        // spun forever at end of input (and read1(char*) overran s).
        inline void read1(int &x){
            int ch,bo=0;x=0;
            for (ch=getchar();ch!=EOF&&(ch<'0'||ch>'9');ch=getchar())if (ch=='-')bo=1;
            for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
            if (bo)x=-x;
        }
        inline void read1(ll &x){
            int ch,bo=0;x=0;
            for (ch=getchar();ch!=EOF&&(ch<'0'||ch>'9');ch=getchar())if (ch=='-')bo=1;
            for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
            if (bo)x=-x;
        }
        inline void read1(double &x){
            int ch,bo=0;x=0;
            for (ch=getchar();ch!=EOF&&(ch<'0'||ch>'9');ch=getchar())if (ch=='-')bo=1;
            for (;ch>='0'&&ch<='9';x=x*10+ch-'0',ch=getchar());
            if (ch=='.'){
                double tmp=1;
                for (ch=getchar();ch>='0'&&ch<='9';tmp/=10.0,x+=tmp*(ch-'0'),ch=getchar());
            }
            if (bo)x=-x;
        }
        inline void read1(char *s){
            int ch=getchar();
            for (;ch!=EOF&&blank(ch);ch=getchar());
            for (;ch!=EOF&&!blank(ch);ch=getchar())*s++=ch;
            *s=0;
        }
        inline void read1(char &c){for (c=getchar();blank(c);c=getchar());}
        //scanf->read
        inline void read2(int &x){scanf("%d",&x);}
        inline void read2(ll &x){
    #ifdef _WIN32
            scanf("%I64d",&x);
    #else
            scanf("%lld",&x);   // previously only __linux was handled; other systems printed an error
    #endif
        }
        inline void read2(double &x){scanf("%lf",&x);}
        inline void read2(char *s){scanf("%s",s);}
        inline void read2(char &c){scanf(" %c",&c);}
        // Reads one line into s without its trailing newline (no bounds check, like the
        // gets() it replaces; gets() was removed from the language in C++14).
        inline void readln2(char *s){
            int ch;
            while ((ch=getchar())!=EOF&&ch!='\n')*s++=ch;
            *s=0;
        }
        //fwrite->write
        // Buffered stdout writer: the buffer is written out when full and on destruction.
        struct Ostream_fwrite{
            char *buf,*p1,*pend;
            Ostream_fwrite(){buf=new char[BUF_SIZE];p1=buf;pend=buf+BUF_SIZE;}
            void out(char ch){
                if (p1==pend){
                    fwrite(buf,1,BUF_SIZE,stdout);p1=buf;
                }
                *p1++=ch;
            }
            // Decimal form of x. The magnitude is computed in unsigned arithmetic, so
            // LLONG_MIN (and INT_MIN via print(int)) no longer overflows on negation.
            void print(ll x){
                static char s[25],*s1;s1=s;
                unsigned long long u=x;
                if (x<0)out('-'),u=0ULL-u;
                if (!u)*s1++='0';
                while(u)*s1++=(char)(u%10+'0'),u/=10;
                while(s1--!=s)out(*s1);
            }
            void println(ll x){print(x);out('\n');}
            void print(int x){print((ll)x);}
            void println(int x){print(x);out('\n');}
            // Fixed-point output of x with y digits after the point (0 <= y <= 17), rounding half up.
            void print(double x,int y){
                static ll mul[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,
                                 1000000000,10000000000LL,100000000000LL,1000000000000LL,10000000000000LL,
                                 100000000000000LL,1000000000000000LL,10000000000000000LL,100000000000000000LL};
                if (x<-1e-12)out('-'),x=-x;x*=mul[y];
                ll x1=(ll)floor(x); if (x-floor(x)>=0.5)++x1;
                ll x2=x1/mul[y],x3=x1-x2*mul[y]; print(x2);
                // Zero-pad the fractional part before printing its significant digits.
                if (y>0){out('.'); for (int i=1;i<y&&x3*mul[i]<mul[y];out('0'),++i); print(x3);}
            }
            void println(double x,int y){print(x,y);out('\n');}
            void print(const char *s){while (*s)out(*s++);}
            void println(const char *s){print(s);out('\n');}
            void flush(){if (p1!=buf){fwrite(buf,1,p1-buf,stdout);p1=buf;}}
            ~Ostream_fwrite(){flush();}
        }Ostream;
        inline void print(int x){Ostream.print(x);}
        inline void println(int x){Ostream.println(x);}
        inline void print(char x){Ostream.out(x);}
        inline void println(char x){Ostream.out(x);Ostream.out('\n');}
        inline void print(ll x){Ostream.print(x);}
        inline void println(ll x){Ostream.println(x);}
        inline void print(double x,int y){Ostream.print(x,y);}
        inline void println(double x,int y){Ostream.println(x,y);}
        inline void print(const char *s){Ostream.print(s);}
        inline void println(const char *s){Ostream.println(s);}
        inline void println(){Ostream.out('\n');}
        inline void flush(){Ostream.flush();}
        //puts->write
        char Out[OUT_SIZE],*o=Out;
        char last1=0;   // last character queued by the print1 family; 0 while nothing is pending
        // Queues one character. If Out is full, it is first written to stdout
        // (the original wrote past the end of Out after OUT_SIZE bytes).
        inline void put1(char c){
            if (o==Out+OUT_SIZE){fwrite(Out,1,o-Out,stdout);o=Out;}
            *o++=c;last1=c;
        }
        inline void print1(ll x){
            static char buf[25];
            char *p1=buf;
            unsigned long long u=x;
            if (x<0)put1('-'),u=0ULL-u;
            if (!u)*p1++='0';
            while(u)*p1++=(char)(u%10+'0'),u/=10;
            while(p1--!=buf)put1(*p1);
        }
        inline void print1(int x){print1((ll)x);}
        inline void println1(int x){print1(x);put1('\n');}
        inline void println1(ll x){print1(x);put1('\n');}
        inline void print1(char c){put1(c);}
        inline void println1(char c){put1(c);put1('\n');}
        inline void print1(const char *s){while (*s)put1(*s++);}
        inline void println1(const char *s){print1(s);put1('\n');}
        inline void println1(){put1('\n');}
        // Writes out everything queued. Like the original puts()-based version, it
        // makes the output end with a newline if anything was printed.
        inline void flush1(){
            if (o!=Out){fwrite(Out,1,o-Out,stdout);o=Out;}
            if (last1&&last1!='\n')putchar('\n');
            last1=0;
        }
        struct puts_write{
            ~puts_write(){flush1();}
        }_puts;
        //printf->write
        inline void print2(int x){printf("%d",x);}
        inline void println2(int x){printf("%d\n",x);}
        inline void print2(char x){printf("%c",x);}
        inline void println2(char x){printf("%c\n",x);}
        inline void print2(ll x){
    #ifdef _WIN32
            printf("%I64d",x);
    #else
            printf("%lld",x);
    #endif
        }
        inline void println2(ll x){print2(x);printf("\n");}
        inline void println2(){printf("\n");}
    #undef ll
    #undef OUT_SIZE
    #undef BUF_SIZE
    }  // namespace fastIO
    using namespace fastIO;
     
    int T;              // number of test cases
    int n;              // length of the current array
    int s[maxn];        // s[1..n]: input values, discretized in place to 1..pn
    int has[maxn];      // has[1..pn]: sorted distinct input values
    int pn;             // number of distinct values
    int x[maxn];        // suffix-array rank buffer
    int y[maxn];        // suffix-array scratch buffer
    int c[maxn];        // counting-sort buckets
    int sa[maxn];       // sa[r]: start of the r-th smallest suffix
    int rk[maxn];       // rk[i]: rank of the suffix starting at i
    int height[maxn];   // height[r]: LCP of suffixes sa[r] and sa[r-1]
     
    void SA() {
        int m = pn;
        for (int i = 0; i <= m; ++i) c[i] = 0;
        for (int i = 1; i <= n; ++i) ++c[(x[i]=s[i])];
        for (int i = 1; i <= m; ++i) c[i] += c[i-1];
        for (int i = n; i >= 1; --i) sa[c[x[i]]--] = i;
     
        for (int p, k = 1; k <= n; k <<= 1) {
            p = 0;
            for (int i = n-k+1; i <= n; ++i) y[++p] = i ;
            for (int i = 1; i <= n; ++i) {
                if(sa[i] > k) y[++p] = sa[i] - k;
            }
     
            for (int i = 0; i <= m; ++i) c[i] = 0;
            for (int i = 1; i <= n; ++i) ++c[x[y[i]]];
            for (int i = 1; i <= m; ++i) c[i] += c[i-1];
            for (int i = n; i >= 1; --i) sa[c[x[y[i]]]--] = y[i];
     
            p = y[sa[1]] = 1;
            for (int i = 2; i <= n; ++i) {
                int a = sa[i]+k > n? -1 : x[sa[i]+k];
                int b = sa[i-1]+k > n ? -1: x[sa[i-1]+k];
                y[sa[i]] = (x[sa[i]] == x[sa[i-1]] && a == b) ? p : ++p;
            }
            swap(x, y);
            if(p >= n) break;
            m = p;
        }
        for (int i = 1; i <= n; ++i) rk[sa[i]] = i;
    }
     
    // Kasai's LCP construction: height[rk[i]] = LCP(suffix i, its predecessor in sa).
    // The running match length k shrinks by at most one between consecutive starts i.
    void getHeight() {
        int k = 0;
        for (int i = 1; i <= n; ++i) {
            k = k > 0 ? k - 1 : 0;
            const int prev = sa[rk[i] - 1];
            for (; s[i + k] == s[prev + k]; ++k) {}
            height[rk[i]] = k;
        }
    }
     
    // Builds the suffix array (sa, rk) and then the LCP array (height) for s[1..n].
    // The order matters: getHeight() reads sa and rk.
    void build() { SA(); getHeight(); }
     
    int sta[maxn], top;     // monotonic stack used to compute nxt[]
    int nxt[maxn];          // nxt[i]: first j > i with s[j] > s[i] (n+1 if none)
    // st[i][j]: index of the maximum of s[i .. i+2^j-1], rightmost on ties.
    // Only levels 0..20 are ever built or queried, so 21 columns suffice.
    int st[maxn][21];
    int pw2[maxn];          // pw2[len] = floor(log2(len)) + 1 for len >= 1, used by query()
    ll suf[maxn];           // suf[i]: sum over r in [i, n] of max(a[i..r]), in original values
     
    // Index of the maximum of s[l..r] (requires l <= r), from two overlapping
    // sparse-table windows; on equal values the right window's index is returned.
    int query(int l, int r) {
        const int lvl = pw2[r - l + 1] - 1;          // floor(log2(length))
        const int leftIdx = st[l][lvl];
        const int rightIdx = st[r - (1 << lvl) + 1][lvl];
        return s[leftIdx] > s[rightIdx] ? leftIdx : rightIdx;
    }
     
    // Precomputes pw2 once, then for each of T test cases prints the sum of maxima
    // over all distinct subarrays.
    int main() {
        // pw2[len] = floor(log2(len)) + 1: mark each power of two, then prefix-sum.
        for (int i = 1; i < maxn; i <<= 1) pw2[i] = 1;
        for (int i = 1; i < maxn; ++i) pw2[i] += pw2[i-1];
        read(T);
        while(T--) {
            read(n);
            for (int i = 1; i <= n; ++i) {
                read(s[i]);
                has[i] = s[i];
            }

            // Discretize to 1..pn for the suffix array; has[s[i]] recovers the original value.
            sort(has+1, has+1+n);
            pn = unique(has+1, has+1+n) - has - 1;
            for (int i = 1; i <= n; ++i) {
                s[i] = lower_bound(has+1, has+1+pn, s[i]) - has;
                st[i][0] = i;
            }

            // nxt[i] = first j > i with s[j] > s[i]; the inf sentinel at n+1 is never popped.
            s[n+1] = inf, sta[(top = 1)] = n+1;
            for (int i = n; i >= 1; --i) {
                while(top && s[sta[top]] <= s[i]) --top;
                nxt[i] = sta[top];
                sta[++top] = i;
            }
            nxt[n+1] = st[n + 1][0] = n + 1;

            // Sparse table: st[i][j] = argmax of s[i .. i+2^j-1]; the right half wins ties.
            for (int j = 1; j <= 20; ++j) {
                int p = 1 << (j - 1), l = (1 << j) - 1;
                for (int i = 1; i + l <= n; ++i) {
                    if(s[st[i][j - 1]] > s[st[i + p][j - 1]]) st[i][j] = st[i][j - 1];
                    else st[i][j] = st[i + p][j - 1];
                }
            }

            // suf[i] = sum over r in [i, n] of max(a[i..r]), in original values.
            suf[n+1] = 0;
            for (int i = n; i >= 1; --i) suf[i] = 1ll * has[s[i]] * (nxt[i] - i);
            for (int i = n; i >= 1; --i) suf[i] += suf[nxt[i]];

            build();

            // Suffix sa[i] adds only its prefixes longer than height[i]; prefixes of length
            // <= height[i] already occur in the suffix ranked just before it.
            ll ans = 0;
            for (int i = 1; i <= n; ++i) {
                int h = height[i];
                if(h == 0) {
                    ans += suf[sa[i]];
                } else {
                    int r = sa[i] + h - 1;
                    int l = sa[i];
                    int pos = query(l, r);      // argmax of the already-counted prefix a[l..r]
                    ans += suf[nxt[pos]] + 1ll * has[s[pos]] * (nxt[pos] - r - 1);
                }
            }
            println(ans);
    //        printf("%lld\n", ans);
        }
        return 0;
    }
    
  • 相关阅读:
    dubbo的超时机制
    今天又遇到之前的问题,后端返回数据long到前端失真
    如何在一台机子上配置两个github
    当sum函数返回null时处理
    Linux中zookeeper安装
    Linux常用指令
    sql执行顺序
    固定时间刷新某个固定值 java
    docker基础
    python之CSS
  • 原文地址:https://www.cnblogs.com/acerkoo/p/11655511.html
Copyright © 2011-2022 走看看