zoukankan      html  css  js  c++  java
  • 莫队算法总结

    莫队算法 总结

    最近两天学习了一下莫队,感觉莫队算法还是挺好用的(现在看到离线询问就想莫队...
    就稍微写一下总结吧,加深一下对算法的理解。

    • 普通莫队

    核心思想:莫队算法一般用来离线处理一系列无修改的区间询问问题,通过将所有的询问保存下来,并且将所有的询问区间进行适当地排序,从而达到降低时间复杂度的效果。

    对于所有的询问区间([l_i,r_i]),如果暴力地进行区间端点移动,那么对于一次询问,区间端点可能移动(n)的长度。假设询问的规模与(n)同级,那么复杂度就为(O(n^2))

    但其实,我们可以巧妙地安排区间顺序以降低时间复杂度。
    莫队算法的思想如下:
    将区间分为(sqrt{n})块,每块的长度也为(sqrt{n}),之后对所有的询问区间排序,如果区间左端点在同一块内,则按右端点排序;否则则按左端点所在块进行排序。
    就这样排序过后,暴力计算就行了,可以证明,时间复杂度为(O(n^{frac{3}{2}}))

    下面给出简单的证明:
    假设区间左端点在同一块内,那么一次询问左端点最多移动(sqrt{n}),由于右端点是单增的,则右端点移动总的复杂度为(O(n)),此时端点移动的总复杂度为(O(n^{frac{3}{2}}))。(注意这里是均摊意义上的复杂度)
    如果区间左端点不在同一块,也就是左端点跨块移动,因为一共有(sqrt{n})块,每次右端点的移动最多(O(n)),此时总的时间复杂度也为(O(n^{frac{3}{2}}))
    所以经过分块过后,时间复杂度可以降为(O(n^frac{3}{2}))

    可以先通过几道例题感受一下:
    洛谷P1494 小Z的袜子

    (cnt[i])为第(i)种颜色的袜子的个数,当前区间为([l,r]),那么容易知道所求的答案为(frac{sum_{i=1}^{k}{C_{cnt[i]}^2}}{C_{r-l+1}^{2}})
    因为分母是与区间长度有关,我们只用考虑区间端点变化时,分子的变化情况就行了。

    先单独把分子拿出来:(sum_{i=1}^{k}{C_{cnt[i]}^2}),当区间范围增加一时,会存在一个(t),有(cnt[t]+1),在这个求和式中,其余项不会改变,那么我们就只用看这一项对答案的影响。
    影响即为:(C_{cnt[t]+1}^2-C_{cnt[t]}^2),那么在进行答案更新时算一下这个式子就好了。
    对于区间范围减小的情况分析也同理。
    代码如下:

    Code
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 50005;
    int n, m, block;
    int a[N];
    struct query{
    	int l, r, id ;
    }Q[N];
    struct Ans{
    	ll p, q;
    }answer[N];
    bool cmp(query x, query y) {
    	if((x.l - 1) / block + 1 == (y.l - 1) / block + 1) return x.r < y.r;
    	return (x.l - 1) / block + 1 < (y.l - 1) / block + 1 ;
    }
    ll gcd(ll A, ll B) {
    	return B == 0 ? A : gcd(B, A % B) ;
    }
    ll ans ;
    ll cnt[N] ;
    void update(int pos, int sign) {
    	ans -= cnt[a[pos]] * cnt[a[pos]] ;
    	cnt[a[pos]] += sign ;
    	ans += cnt[a[pos]] * cnt[a[pos]] ;
    }
    int main() {
    	scanf("%d%d",&n, &m) ;
    	block = (int)sqrt(n) ;
    	for(int i = 1; i <= n; i++) scanf("%d", &a[i]) ;
    	for(int i = 1; i <= m; i++) {
    		scanf("%d%d",&Q[i].l, &Q[i].r) ;
    		Q[i].id = i ;
    	}
    	sort(Q + 1, Q + m + 1, cmp) ;
    	int l = 1, r = 0;
    	for(int i = 1; i <= m; i++) {
    		for(; r < Q[i].r; r++) update(r + 1, 1) ;
    		for(; r > Q[i].r; r--) update(r, -1) ;
    		for(; l < Q[i].l; l++) update(l, -1) ;
    		for(; l > Q[i].l; l--) update(l - 1, 1) ;
    		answer[Q[i].id].p = ans - Q[i].r + Q[i].l - 1;
    		answer[Q[i].id].q = 1ll * (Q[i].r - Q[i].l + 1) * (Q[i].r - Q[i].l) ;
    		if(Q[i].l == Q[i].r) answer[Q[i].id].p = 0, answer[Q[i].id].q = 1;				
    		ll g = gcd(answer[Q[i].id].p, answer[Q[i].id].q) ;
    		answer[Q[i].id].p /= g; answer[Q[i].id].q /= g;
    	}
    	for(int i = 1; i <= m; i++) printf("%lld/%lld
    ",answer[i].p, answer[i].q) ;
    	return 0;
    }
    

    洛谷P3245 大数

    给出的数字串挺长的,但是质数(p)不是很大。
    我们知道,如果一个数字(t)(p)的倍数,那么就有(tmod p=0)。但是区间中的子串很多,我们直接时间复杂度等同于暴力。所以我们可以考虑将问题转化一下。

    设串(s)所在区间为([l,r]),串的长度为(n),那么我们知道(s*10^{r-l+1}=t[l,l+1,cdots,n]-t[r+1,r+2,cdots,n]*10^{r-l+1})
    所以当质数(p)不为2和5时,(smod p=0 => (t[l,l+1,cdots,n]-t[r+1,r+2,cdots,n]*10^{r-l+1}) mod p=0 => t[l,l+1,cdots,n]mod p=t[r+1,r+2,cdots,n]mod p)
    所以我们就可以维护一个数组(f[i]),表示后缀(i)(p)取余的值为多少,那么我们就可以将一个区间为([l,r])的询问转化为([l,r+1])中有多少对(f)相等了。
    之后就用莫队来搞,计算区间范围增加或者减小对答案的影响就好了。思路同上一题类似。

    对于(p)为2或者5的情况,特判一波,维护前缀个数就好了。
    代码如下:

    Code
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 2e5 + 5;
    ll p, cnt;
    int block, n;
    char s[N] ;
    ll f[N], d[N];
    ll num[N] ;
    ll sum[N][3] ;
    struct Query{
    	int l, r, id;
    	ll ans ;
    }q[N];
    int Q;
    bool cmp(Query x, Query y) {
    	if((x.l - 1) / block + 1 == (y.l - 1) / block + 1) return x.r < y.r;
    	return (x.l - 1) / block + 1 < (y.l - 1) / block + 1 ;
    }
    bool cmp2(Query x, Query y) {
    	return x.id < y.id ;
    }
    void sol1() {
    	for(int i = 1; i <= n; i++) {
    		sum[i][0] = sum[i - 1][0] ;
    		sum[i][1] = sum[i - 1][1] ;
    		sum[i][2] = sum[i - 1][2] ;
    		if(p == 2 && (s[i] - '0') % p == 0) sum[i][0] += i, sum[i][2]++;
    		if(p == 5 && (s[i] - '0') % p == 0) sum[i][1] += i, sum[i][2]++;
    	}
    	for(int i = 1; i <= Q; i++) {
    		int l = q[i].l, r = q[i].r;
    		ll k;
    		if(p == 2) k = sum[r][0] - sum[l - 1][0] ;
    		else k = sum[r][1] - sum[l - 1][1] ;
    		q[i].ans = k - (sum[r][2] - sum[l - 1][2]) * (l - 1) ;
    	}
    }
    void update2(int pos, int sign) {
    	cnt -= (num[f[pos]] - 1) * num[f[pos]] / 2;
    	num[f[pos]] += sign;
    	cnt += (num[f[pos]] - 1) * num[f[pos]] / 2;
    }
    void sol2() {
    	int l = 1, r = 0 ;
    	for(int i = 1; i <= Q; i++) {
    		q[i].r += 1;
    		for(; r < q[i].r; r++) update2(r + 1, 1) ;
    		for(; r > q[i].r; r--) update2(r, -1) ;
    		for(; l > q[i].l; l--) update2(l - 1, 1) ;
    		for(; l < q[i].l; l++) update2(l, -1) ;
    		q[i].ans = cnt ;		
    	}
    }
    int main() {
    	scanf("%lld%s%d", &p, s + 1, &Q) ;
    	n = strlen(s + 1) ;
    	block = (int)sqrt(n) ;
    	for(int i = 1; i <= Q; i++) {
    		scanf("%d%d",&q[i].l, &q[i].r) ;
    		q[i].id = i;
    	}
    	sort(q + 1, q + Q + 1, cmp) ;
    	ll x = 0, qp = 1;
    	int flag = -1;
    	for(int i = n; i >= 1; i--) {
    		x = (x + (s[i] - '0') * qp % p) % p;
    		d[i] = f[i] = x;
    		if(f[i] == 0) flag = i;
    		qp = qp * 10 % p;
    	}
    	sort(d + 1, d + n + 1) ;
    	int D = unique(d + 1, d + n + 1) - d - 1;
    	for(int i = 1; i <= n; i++) f[i] = lower_bound(d + 1, d + n + 1, f[i]) - d;
    	if(flag > 0) f[n + 1] = f[flag] ;	
    	if(p == 2 || p == 5) sol1() ;
    	else sol2() ;
    	sort(q + 1, q + Q + 1, cmp2) ;
    	for(int i = 1; i <= Q; i++) printf("%lld
    ", q[i].ans) ; 
    	return 0;
    }
    

    P3246 序列

    感觉这个题挺好的。没想到还可以用莫队来搞。
    对于区间([l,r]),假设我们要将(r)增加1,那么就会多出(r-l+2)个序列,我们就分析他们对答案的影响。
    假设区间([l,r])中最小值所在位置为(p),那么很显然,左端点在([l,l+1,cdots,p])时,区间最小值就为(a[p])

    对于(r+1)而言,如果我们找到左边第一个比他小的位置为(k),那么此时对答案的贡献就为((r-k+2)*a[k]);同理对(k)也可以执行同样的操作。最后必然会存在一个位置(q),其左边第一个比他小的位置为(q),那么操作在这里就终止了。

    每次这么操作时间复杂度过高,发现可以维护一个类似于前缀和一样的东西,递推地来维护就行了。设该前缀和函数为(f),那么区间右端点增加一位对答案的贡献为:(a[p]*(p-l+1)+f[r+1]-f[p])
    这样就可以O(1)算出对答案的影响了。
    左端点的情况也类似考虑。
    代码如下:

    Code
    #include <bits/stdc++.h>
    #define INF 0x3f3f3f3f
    using namespace std;
    typedef long long ll;
    const int N = 2e5 + 5;
    int n, m, block;
    int a[N];
    struct Query{
        int l, r, id;
        ll ans;
    }q[N];
    bool cmp(Query A, Query B) {
        if((A.l - 1) / block  + 1 == (B.l - 1) / block + 1) return A.r < B.r;
        return (A.l - 1) / block + 1< (B.l - 1) / block + 1;
    }
    bool cmp_id(Query A, Query B) {
        return A.id < B.id ;
    }
    int l[N], r[N] ;
    int sta[N], top;
    ll f12[N], f21[N];
    int f[N][22], pos[N][22], Log2[N];
    ll ans ;
    int Get_min(int L, int R) {
        ll k = Log2[R - L + 1];
        if(f[L][k] > f[R - (1LL << k) + 1][k]) return pos[R - (1LL << k) + 1][k] ;
        return pos[L][k] ;
    }
    void update1(int pos, int L, int R, int sign) {
        int p = Get_min(L, R) ;
        ll sum = f12[R] - f12[p] + 1ll * (p - L + 1) * a[p];
        ans += 1ll * sign * sum;
    }
    void update2(int pos, int L, int R, int sign) {
        int p = Get_min(L, R) ;
        ll sum = f21[L] - f21[p] + 1ll * (R - p + 1) * a[p] ;
        ans += 1ll * sign * sum ;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n >> m ;
        Log2[1] = 0;
        for(int i = 2; i <= n; i++) Log2[i] = Log2[i >> 1] + 1;
        block = sqrt(n) ;
        memset(f, INF, sizeof(f)) ;
        for(int i = 1; i <= n; i++) {
            cin >> a[i] ;
            f[i][0] = a[i] ;
            pos[i][0] = i ;
        }
        for(int j = 1; j <= 17; j++) {
            for(int i = 1; i + (1 << (j - 1)) <= n; i++) {
                if(f[i][j - 1] > f[i + (1 << (j - 1))][j - 1]) {
                    f[i][j] = f[i + (1 << (j - 1))][j - 1] ;
                    pos[i][j] = pos[i + (1 << (j - 1))][j - 1] ;
                } else {
                    f[i][j] = f[i][j - 1];
                    pos[i][j] = pos[i][j - 1] ;
                }
            }
        }
        for(int i = 1; i <= n + 1; i++) {
            while(top > 0 && a[sta[top]] >= a[i]) r[sta[top--]] = i ;
            sta[++top] = i;
        }
        top = 0;
        for(int i = n; i >= 0; i--) {
            while(top > 0 && a[sta[top]] >= a[i]) l[sta[top--]] = i;
            sta[++top] = i;
        }
        for(int i = 1; i <= n; i++)
            f12[i] = f12[l[i]] + 1ll * (i - l[i]) * a[i] ;
        for(int i = n; i >= 1; i--)
            f21[i] = f21[r[i]] + 1ll * (r[i] - i) * a[i] ;
        for(int i = 1; i <= m; i++) {
            int L, R;
            cin >> L >> R;
            q[i].l = L; q[i].r = R;
            q[i].id = i;
        }
        sort(q + 1, q + m + 1, cmp) ;
        int L = 1, R = 0;
        for(int i = 1; i <= m; i++) {
            for(; R < q[i].r; R++) update1(R + 1, L, R + 1, 1) ;
            for(; R > q[i].r; R--) update1(R, L, R, -1) ;
            for(; L < q[i].l; L++) update2(L, L, R, -1) ;
            for(; L > q[i].l; L--) update2(L - 1, L - 1, R, 1) ;
            q[i].ans = ans ;
        }
        sort(q + 1, q + m + 1, cmp_id) ;
        for(int i = 1; i <= m; i++)
            cout << q[i].ans << '
    ' ;
        return 0;
    }
    
    • 带修改莫队

    之前说的莫队是不支持修改的,但其实也可以支持修改,只需要再加一维“时间状态”就行了,对于每个询问,新增一维,变为([l,r,k]),表示当前区间为([l,r]),之前经过(k)次修改操作的询问。
    为什么这样是正确的呢?
    因为我们如果知道了([l,r,k])的答案,那么就很容易知道([l+1,r,k],[l-1,r,k],[l,r-1,k],[l,r+1,k],[l,r,k-1],[l,r,k+1])对答案的影响。
    具体来说,修改时间维度时,看看修改的位置是否在([l,r])中,如果在则会对答案产生影响,否则直接修改就是了。之后区间端点左右移动时,遇到的位置也一定是完成(k)次修改过后的值了。

    此时我们还是将区间进行分块,但现在要分为(n^frac{2}{3})块,每块长度为(n^frac{1}{3})。然后以左端点所在的块为第一关键字,右端点所在的块为第二关键字,修改次数为第三关键字进行排序。

    可以证明这样的时间复杂度是(O(n^frac{5}{3}))的。
    证明方法就类似于上面的分析。

    来看一道例题:
     

    [洛谷P1903 数颜色 / 维护队列](https://www.luogu.org/problemnew/show/P1903)

    这就是个待修改莫队的模板题,多了一维对时间的修改,详细见代码吧:

    Code
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 50005, MAX = 1e6 + 5;
    int n, m, block, num, M, l, r, t, ans;
    char ss[5];
    int c[N], cnt[MAX], last[N];
    struct Upd{
        int pos, v;
    }upd[N];
    struct query{
        int l, r, ans, id, k;
    }q[N];
    bool cmp(query a, query b) {
        if((a.l - 1) / block == (b.l - 1) / block && (a.r - 1) / block == (b.r - 1) / block) return a.k < b.k;
        else if((a.l - 1) / block == (b.l - 1) / block) return (a.r - 1) / block < (b.r - 1) / block ;
        return (a.l - 1) / block < (b.l - 1) / block;
    }
    bool cmp_id(query a, query b) {
        return a.id < b.id;
    }
    void update_add(int T) {
        int pos = upd[T].pos, v = upd[T].v;
        last[T] = c[pos] ;
        if(l <= pos && pos <= r) {
            cnt[c[pos]]--;
            if(cnt[c[pos]] == 0) ans--;
            cnt[v]++;
            if(cnt[v] == 1) ans++;
        }
        c[pos] = v;
    }
    void update_del(int T) {
        int pos = upd[T].pos, v = upd[T].v;
        if(l <= pos && pos <= r) {
            cnt[v]--;
            if(cnt[v] == 0) ans--;
            c[pos] = last[T] ;
            cnt[c[pos]]++;
            if(cnt[c[pos]] == 1) ans++;
        } else c[pos] = last[T] ;
    }
    void update(int pos, int val) {
        cnt[c[pos]] += val;
        if(val == 1) {
            if(cnt[c[pos]] == 1) ans++;
        } else if(val == -1)
            if(cnt[c[pos]] == 0) ans--;
    }
    int main() {
        scanf("%d%d",&n, &m) ;
        block = pow(n, 0.666666) ;
        for(int i = 1; i <= n; i++) scanf("%d", &c[i]) ;
        for(int i = 1; i <= m; i++) {
            scanf("%s",ss) ;
            if(ss[0] == 'R') {
                int pos, v;
                scanf("%d%d",&pos, &v) ;
                upd[++num].pos = pos; upd[num].v = v ;
            } else {
                int l, r;
                scanf("%d%d",&l, &r) ;
                q[++M].l = l; q[M].r = r;
                q[M].id = M; q[M].k = num;
            }
        }
        sort(q + 1, q + M + 1, cmp) ;
        l = 1, r = 0, t = 0;
        for(int i = 1; i <= M; i++) {
            for(; t < q[i].k; t++) update_add(t + 1) ;
            for(; t > q[i].k; t--) update_del(t) ;
            for(; r < q[i].r; r++) update(r + 1, 1) ;
            for(; r > q[i].r; r--) update(r, -1) ;
            for(; l < q[i].l; l++) update(l, -1) ;
            for(; l > q[i].l; l--) update(l - 1, 1) ;
            q[i].ans = ans ;
        }
        sort(q + 1, q + M + 1, cmp_id) ;
        for(int i = 1; i <= M; i++) printf("%d
    ", q[i].ans) ;
        return 0 ;
    }
    
    • 树上莫队

    如果可以对树进行分块的话,那么也可以对树上的询问用莫队来搞。刚好有一道树上分块的模板题
    那么树上莫队的具体做法就为,首先将树进行分块,然后对所有的询问([x,y]),首先让(x)的时间戳小于(y)的时间戳,然后就按照(x)所在的块为第一关键字,以y的时间戳为第二关键字进行排序就好了。
    之后考虑询问间的转移,方法为直接将(x_i->x_{i+1})路径上面的所有点除开它们lca的状态取反,同理也将(y_i->y_{i+1})路径上面的所有点除开它们lca的状态取反,计算答案就是了。
    具体证明直接引用vfk的博客:

    用S(v, u)代表 v到u的路径上的结点的集合。
    用root来代表根结点,用lca(v, u)来代表v、u的最近公共祖先。
    那么
    S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
    其中xor是集合的对称差。
    简单来说就是节点出现两次消掉。
    lca很讨厌,于是再定义
    T(v, u) = S(root, v) xor S(root, u)
    观察将curV移动到targetV前后T(curV, curU)变化:
    T(curV, curU) = S(root, curV) xor S(root, curU)
    T(targetV, curU) = S(root, targetV) xor S(root, curU)
    取对称差:
    T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
    由于对称差的交换律、结合律:
    T(curV, curU) xor T(targetV, curU)= S(root, curV) xorS(root, targetV)
    两边同时xor T(curV, curU):
    T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
    发现最后两项很爽……哇哈哈
    T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
    (有公式恐惧症的不要走啊 T_T)
    也就是说,更新的时候,xor T(curV, targetV)就行了。
    即,对curV到targetV路径(除开lca(curV, targetV))上的结点,将它们的存在性取反即可。

    因为lca我们不会算,所以最后单独考虑一下lca就行了。

    这是树上莫队的第一种解法,另外还有一种就是直接将树转化为dfs序,压缩成线性的,同时每个结点维护两个时间戳,一个是进去的时间戳,一个是出来的时间戳。
    那么对于树上的路径比如从(x)(y),若(LCA(x,y))为其中之一,那么两个的路径在dfs序中的体现就为(in[x]->in[y]);否则就为(out[x]->in[y])

    这样写的话也需要一个数组来记录当前结点是否被算入答案中,每到一个位置也要将相应的状态取反。这里注意第二种情况lca也不会算上,所以也要单独考虑一下lca。

    既然有了树上莫队,也有树上带修改莫队,好吧,其实原理都是差不多的。

    看一个例题:
     

    [P4074 糖果公园](https://www.luogu.org/problemnew/show/P4074)

    这基本上就是莫队算法的集大成者了。对答案的影响很好计算,维护一种颜色出现的次数就行了。
    主要就是代码,我写了两种,一种是dfs序的,一种是树上分块的。

    dfs序
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 1e5 + 5;
    int n, m, qq, block;
    ll w[N], c[N], in[N], out[N], v[N];
    vector <int> g[N] ;
    struct Query{
        int l, r, id, k;
        ll ans ;
    }q[N];
    struct Upd{
        int x, y, last;
    }upd[N];
    bool cmp_id(Query A, Query B) {
        return A.id < B.id ;
    }
    bool cmp(Query A, Query B) {
        if((A.l - 1) / block == (B.l - 1) / block && (A.r - 1) / block == (B.r - 1) / block) return A.k < B.k;
        if((A.l - 1) / block == (B.l - 1) / block) return (A.r - 1) / block < (B.r - 1) / block;
        return (A.l - 1) / block < (B.l - 1) / block ;
    }
    int dfn;
    ll a[2 * N], f[N][22], deep[N], pre[N];
    void dfs(int u, int fa) {
        in[u] = ++dfn;
        a[dfn] = u ;
        deep[u] = deep[fa] + 1;
        for(auto v : g[u]) {
            if(v == fa) continue ;
            f[v][0] = u;
            for(int i = 1; i <= 17; i++) f[v][i] = f[f[v][i - 1]][i - 1] ;
            dfs(v, u);
        }
        out[u] = ++dfn;
        a[dfn] = u;
    }
    int LCA(int x, int y) {
        if(deep[x] < deep[y]) swap(x, y) ;
        for(int i = 17; i >= 0; i--)
            if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
        if(x == y) return x;
        for(int i = 17; i >= 0; i--)
            if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
        return f[x][0] ;
    }
    ll ans ;
    int l, r, t, qnum, num;
    bool vis[2 * N];
    ll cnt[N] ;
    void update(int u) {
        int col = c[u] ;
        if(vis[u]) ans -= 1ll * w[cnt[col]--] * v[col] ;
        else ans += 1ll * w[++cnt[col]] * v[col] ;
        vis[u] ^= 1;
    }
    void update_t(int T, int sign) {
        int u = upd[T].x, col = upd[T].y;
        if(sign == -1) col = upd[T].last;
        if(vis[u]) {
            update(u);
            c[u] = col;
            update(u);
        } else c[u] = col;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n >> m >> qq;
        block = pow(n, 0.666666) ;
        for(int i = 1; i <= m; i++) cin >> v[i] ;
        for(int i = 1; i <= n; i++) cin >> w[i] ;
        for(int i = 1; i < n; i++) {
            int u, v;
            cin >> u >> v;
            g[u].push_back(v) ;
            g[v].push_back(u) ;
        }
        dfs(1, 0);
        for(int i = 1; i <= n; i++) cin >> c[i], pre[i] = c[i];
        for(int i = 1; i <= qq; i++) {
            int op, x, y;
            cin >> op >> x >> y;
            if(op == 1) {
                if(in[x] > in[y]) swap(x, y) ;
                int lca = LCA(x, y) ;
                q[++num].r = in[y];
                q[num].k = qnum;
                q[num].id = num;
                if(lca == x) q[num].l = in[x] ;
                else q[num].l = out[x];
            } else {
                upd[++qnum].x = x;
                upd[qnum].y = y;
                //pre[qnum] = (qnum == 1 ? c[x] : upd[qnum - 1].y) ;
                upd[qnum].last = pre[x];
                pre[x] = y;
                
            }
        }
        sort(q + 1, q + num + 1, cmp) ;
        l = 1, r = 0, t = 0;
        for(int i = 1; i <= num; i++) {
            for(; t < q[i].k; t++) update_t(t + 1, 1) ;
            for(; t > q[i].k; t--) update_t(t, -1) ;
            for(; r < q[i].r; r++) update(a[r + 1]) ;
            for(; r > q[i].r; r--) update(a[r]) ;
            for(; l < q[i].l; l++) update(a[l]) ;
            for(; l > q[i].l; l--) update(a[l - 1]) ;
            int lca = LCA(a[l], a[r]) ;
            if(lca != a[l] && lca != a[r]) {
                update(lca) ;
                q[i].ans = ans ;
                update(lca) ;
            } else q[i].ans = ans ;
        }
        sort(q + 1, q + num + 1, cmp_id) ;
        for(int i = 1; i <= num; i++)
            cout << q[i].ans << '
    ' ;
        return 0;
    }
    
     
    树上分块
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 1e5 + 5;
    int n, m, qq, block;
    int w[N], c[N], in[N], v[N];
    int dfn;
    int f[N][22], deep[N], pre[N];
    int sta[N], bel[N];
    int top, tot;
    vector <int> g[N] ;
    struct Query{
        int l, r, id, k;
        ll ans ;
    }q[N];
    struct Upd{
        int x, y, last;
    }upd[N];
    bool cmp_id(Query A, Query B) {
        return A.id < B.id ;
    }
    bool cmp(Query A, Query B) {
        if(bel[A.l] == bel[B.l] && bel[A.r] == bel[B.r]) return A.k < B.k;
        if(bel[A.l] == bel[B.l]) return bel[A.r] < bel[B.r] ;
        return bel[A.l] < bel[B.l] ;
    }
    void dfs(int u, int fa) {
        in[u] = ++dfn;
        deep[u] = deep[fa] + 1;
        int tmp = top ;
        for(auto v : g[u]) {
            if(v == fa) continue ;
            f[v][0] = u;
            for(int i = 1; i <= 16; i++) f[v][i] = f[f[v][i - 1]][i - 1] ;
            dfs(v, u);
            if(top - tmp >= block) {
                tot++;
                while(top > tmp) bel[sta[top--]] = tot;
            }
        }
        sta[++top] = u ;
    }
    int LCA(int x, int y) {
        if(deep[x] < deep[y]) swap(x, y) ;
        for(int i = 16; i >= 0; i--)
            if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
        if(x == y) return x;
        for(int i = 16; i >= 0; i--)
            if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
        return f[x][0] ;
    }
    ll ans ;
    int l, r, t, qnum, num;
    bool vis[N];
    ll cnt[N] ;
    void modify(int x) {
        int col = c[x] ;
        if(vis[x]) ans -= 1ll * w[cnt[col]--] * v[col] ;
        else ans += 1ll * w[++cnt[col]] * v[col] ;
        vis[x] ^= 1;
    }
    void update(int x, int y) {
        while(x != y) {
            if(deep[x] >= deep[y]) modify(x), x = f[x][0] ;
            else modify(y), y = f[y][0] ;
        }
    }
    void change(int x, int col) {
        if(vis[x]) {
            modify(x) ;
            c[x] = col ;
            modify(x) ;
        } else c[x] = col ;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n >> m >> qq;
        block = pow(n, 0.666666) ;
        for(int i = 1; i <= m; i++) cin >> v[i] ;
        for(int i = 1; i <= n; i++) cin >> w[i] ;
        for(int i = 1; i < n; i++) {
            int u, v;
            cin >> u >> v;
            g[u].push_back(v) ;
            g[v].push_back(u) ;
        }
        dfs(1, 0);
        for(int i = 1; i <= n; i++) cin >> c[i], pre[i] = c[i] ;
        for(int i = 1; i <= qq; i++) {
            int op, x, y;
            cin >> op >> x >> y;
            if(op == 1) {
                if(in[x] > in[y]) swap(x, y) ;
                q[++num].id = num;q[num].l = x;
                q[num].r = y;q[num].k = qnum;
            } else {
                upd[++qnum].x = x;upd[qnum].y = y;
                upd[qnum].last = pre[x] ;
                pre[x] = upd[qnum].y ;
            }
        }
        sort(q + 1, q + num + 1, cmp) ;
        l = q[1].l, r = q[1].r, t = 0;
        update(l, r);
        for(int i = 1; i <= num; i++) {
            for(;t < q[i].k; t++) change(upd[t + 1].x, upd[t + 1].y) ;
            for(;t > q[i].k; t--) change(upd[t].x, upd[t].last) ;
            update(l, q[i].l) ;
            update(r, q[i].r) ;
            int lca = LCA(q[i].l, q[i].r) ;
            modify(lca) ;
            q[q[i].id].ans = ans ;
            modify(lca) ;
            l = q[i].l, r = q[i].r ;
        }
        for(int i = 1; i <= num; i++) cout << q[i].ans << '
    ' ;
        return 0;
    }
    

     
    再看看这个题:
    CF375D Tree and Queries

    这里询问的是出现次数大于等于k的颜色有多少种,看似比较棘手。实际上我们维护一个数组(sum[i]),表示大于等于(i)的颜色有多少种就行了。这个稍微想想还是比较清楚的。
    代码如下:

    Code
    #include <bits/stdc++.h>
    using namespace std;
    typedef long long ll;
    const int N = 2e5 + 5;
    int n, m, block;
    int c[2 * N], a[2 * N], cnt[N];
    int ans ;
    vector <int> g[N];
    struct Query{
        int l, r, k, id, ans;
    }q[N];
    bool cmp(Query A, Query B) {
        if((A.l - 1) / block == (B.l - 1) / block) return A.r < B.r;
        return (A.l - 1) / block < (B.l - 1) / block ;
    }
    bool cmp_id(Query A, Query B) {
        return A.id < B.id ;
    }
    int in[N], out[N] ;
    int dfn, tot;
    bool vis[2 * N], has[2 * N];
    int sum[N] ;
    void update(int pos, int val) {
        int col = c[a[pos]] ;
        if(val == 1) {
            if(vis[a[pos]]) return ;
            vis[a[pos]] = 1;
            sum[++cnt[col]]++;
        } else {
            if(!vis[a[pos]]) return ;
            vis[a[pos]] = 0;
            sum[cnt[col]--]--;
        }
    }
    void dfs(int u, int fa) {
        in[u] = ++dfn;
        a[dfn] = u ;
        int t = dfn;
        for(auto v : g[u]) {
            if(v == fa) continue ;
            dfs(v, u) ;
        }
        out[u] = ++dfn;
        a[dfn] = u ;
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n >> m;
        block = sqrt(n) ;
        for(int i = 1; i <= n; i++) cin >> c[i] ;
        for(int i = 1; i < n; i++) {
            int u, v;
            cin >> u >> v;
            g[u].push_back(v);
            g[v].push_back(u);
        }
        dfs(1, 0);
        for(int i = 1; i <= m; i++) {
            int v;
            cin >> v >> q[i].k;
            q[i].id = i;
            q[i].l = in[v] ; q[i].r = out[v] ;
        }
        sort(q + 1, q + m + 1, cmp);
        int l = 1, r = 0;
        for(int i = 1; i <= m; i++) {
            int k = q[i].k ;
            for(; r < q[i].r; r++) update(r + 1, 1) ;
            for(; r > q[i].r; r--) update(r, -1) ;
            for(; l < q[i].l; l++) update(l, -1) ;
            for(; l > q[i].l; l--) update(l - 1, 1) ;
            q[i].ans = sum[k] ;
        }
        sort(q + 1, q + m + 1, cmp_id) ;
        for(int i = 1; i <= m; i++)
            cout << q[i].ans << '
    ' ;
        return 0;
    }
    

     
    最后再来看一道例题:
    BZOJ3289:Mato的文件管理

    学过莫队之后是不是感觉很简单?
    每次区间转移用树状数组维护信息即可。
    代码如下:

    Code
    #include <cstdio>
    #include <cstring>
    #include <algorithm>
    #include <iostream>
    #include <cmath>
    using namespace std;
    typedef long long ll;
    const int N = 50005;
    int c[N], a[N], b[N];
    int l, r ;
    int n, block;
    struct Query{
        int l, r, id ;
        ll ans ;
    }q[N];
    bool cmp(Query A, Query B) {
        if((A.l - 1) / block == (B.l - 1) / block) return A.r < B.r;
        return (A.l - 1) / block < (B.l - 1) / block ;
    }
    int lowbit(int x) {
        return x & (-x) ;
    }
    void add(int x, int val) {
        for(int i = x; i < N; i += lowbit(i)) c[i] += val;
    }
    ll query(int x) {
        ll ans = 0;
        for(int i = x; i > 0; i -= lowbit(i)) ans += c[i];
        return ans ;
    }
    ll ans ;
    void update(int x, int v, int sign) {
        if(sign == 1) {
            if(v == 1) {
                add(a[x], 1) ;
                int sum = query(a[x]) ;
                ans += r - l + 2 - sum ;
            } else {
                int sum = query(a[x]) ;
                ans -= (r - l + 1 - sum) ;
                add(a[x], -1) ;
            }
        } else {
            if(v == 1) {
                int sum = query(a[x] - 1) ;
                ans += sum;
                add(a[x], 1) ;
            } else {
                add(a[x], -1) ;
                int sum = query(a[x] - 1) ;
                ans -= sum ;
            }
        }
    }
    int main() {
        ios::sync_with_stdio(false); cin.tie(0);
        cin >> n;
        block = sqrt(n) ;
        for(int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
        sort(b + 1, b + n + 1);
        int D = unique(b + 1, b + n + 1) - b - 1;
        for(int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + D + 1, a[i]) - b;
        int Q;
        cin >> Q;
        for(int i = 1; i <= Q; i++) {
            int l, r;
            cin >> l >> r;
            q[i].l = l; q[i].r = r; q[i].id = i;
        }
        sort(q + 1, q + Q + 1, cmp) ;
        l = 1, r = 0;
        for(int i = 1; i <= Q; i++) {
            for(; r < q[i].r; r++) update(r + 1, 1, 1) ;
            for(; r > q[i].r; r--) update(r, -1, 1) ;
            for(; l < q[i].l; l++) update(l, -1, -1) ;
            for(; l > q[i].l; l--) update(l - 1, 1, -1) ;
            q[q[i].id].ans = ans ;
        }
        for(int i = 1; i <= Q; i++)
            cout << q[i].ans << '
    ' ;
        return 0;
    }
    
  • 相关阅读:
    SQL Server系统表sysobjects介绍
    tofixed方法 四舍五入
    (function($){})(jQuery);
    DOS批处理命令-字符串操作
    IF ERRORLEVEL 和 IF %ERRORLEVEL% 区别
    Gpupdate命令详解
    DOS批处理中%cd%和%~dp0的区别
    SetACL 使用方法详细参数中文解析
    Lazarus 1.6 增加了新的窗体编辑器——Sparta_DockedFormEditor.ipk
    Lazarus 1.44升级到1.6 UTF8处理发生变化了
  • 原文地址:https://www.cnblogs.com/heyuhhh/p/10827143.html
Copyright © 2011-2022 走看看