zoukankan      html  css  js  c++  java
  • [做题记录-计数相关] [AGC023E] Inversions

    题意

    一个长度为(n)的排列数组, 每个位置有上限的限制, 求所有合法的排列的逆序数的和。

    (n leq 10^5)

    题解

    经典套路是先计数一下序列的个数然后考虑每一对数对答案的贡献。

    考虑从大往小填数, 记(b_i)表示(a_i ge i)的位置个数, 那么合法的排列个数为:

    [cnt = prod_i b_i - (n - i) ]

    然后现在对于序列上一对位置考虑。不妨设(i < j), (p_i)表示填好以后(i)位置上的数是什么。

    (a_i = a_j)时, 显然(p_i >p_j)(p_i < p_j)的情况数相同, 贡献是(frac{cnt}{2})

    (a_i < a_j)的时候, 讨论(a_j)的取值。当(a_j in [a_i + 1, a_j])的时候这里肯定是没有贡献的, 那么考虑强行让(a_j = a_i), 那么这样的话会发现(b_i : which i in [a_i + 1, a_j])会减少(1)。那么这里的贡献是:

    [frac{cnt}{2} prod _{k = a_i + 1}^{a_j}frac{b_k - (n - k) - 1}{b_k - (n - k)} ]

    (a_i > a_j)的时候, 不妨把(i, j)反过来, 讨论就变成了(a_j < a_i), 那么这里的顺序对就和上面情况中逆序对的情况是一样的, 用总方案减去出现顺序对的情况即可。

    [cnt - frac{cnt}{2} prod _{k = a_j + 1}^{a_i}frac{b_k - (n - k) - 1}{b_k - (n - k)} ]

    然后考虑快速计算, 以(a_i < a_j)为例, 考虑从小往大枚举(a_j), 维护一个全局的数据结构, 每次(a_j)增大的时候全局乘, 查询的时候查位置小于(j)的所有位置的权值和, 然后在(j)位置加入一个(frac{cnt}{2})

    /*
    	QiuQiu /qq
      ____    _           _                 __                
      / __   (_)         | |               / /                
     | |  | |  _   _   _  | |  _   _       / /    __ _    __ _ 
     | |  | | | | | | | | | | | | | |     / /    / _` |  / _` |
     | |__| | | | | |_| | | | | |_| |    / /    | (_| | | (_| |
      \___\_ |_|  \__,_| |_|  \__, |   /_/      \__, |  \__, |
                                __/ |               | |     | |
                               |___/                |_|     |_|
    */
    
    #include <bits/stdc++.h>
    
    using namespace std;
    
    class Input {
    	#define MX 1000000
    	private :
    		char buf[MX], *p1 = buf, *p2 = buf;
    		inline char gc() {
    			if(p1 == p2) p2 = (p1 = buf) + fread(buf, 1, MX, stdin);
    			return p1 == p2 ? EOF : *(p1 ++);
    		}
    	public :
    		Input() {
    			#ifdef Open_File
    				freopen("a.in", "r", stdin);
    				freopen("a.out", "w", stdout);
    			#endif
    		}
    		template <typename T>
    		inline Input& operator >>(T &x) {
    			x = 0; int f = 1; char a = gc();
    			for(; ! isdigit(a); a = gc()) if(a == '-') f = -1;
    			for(; isdigit(a); a = gc()) 
    				x = x * 10 + a - '0';
    			x *= f;
    			return *this;
    		}
    		inline Input& operator >>(char &ch) {
    			while(1) {
    				ch = gc();
    				if(ch != '
    ' && ch != ' ') return *this;
    			}
    		}
    		inline Input& operator >>(char *s) {
    			int p = 0;
    			while(1) {
    				s[p] = gc();
    				if(s[p] == '
    ' || s[p] == ' ' || s[p] == EOF) break;
    				p ++; 
    			}
    			s[p] = '';
    			return *this;
    		}
    	#undef MX
    } Fin;
    
    class Output {
    	#define MX 1000000
    	private :
    		char ouf[MX], *p1 = ouf, *p2 = ouf;
    		char Of[105], *o1 = Of, *o2 = Of;
    		void flush() { fwrite(ouf, 1, p2 - p1, stdout); p2 = p1; }
    		inline void pc(char ch) {
    			* (p2 ++) = ch;
    			if(p2 == p1 + MX) flush();
    		}
    	public :
    		template <typename T> 
    		inline Output& operator << (T n) {
    			if(n < 0) pc('-'), n = -n;
    			if(n == 0) pc('0');
    			while(n) *(o1 ++) = (n % 10) ^ 48, n /= 10;
    			while(o1 != o2) pc(* (--o1));
    			return *this; 
    		}
    		inline Output & operator << (char ch) {
    			pc(ch); return *this; 
    		}
    		inline Output & operator <<(const char *ch) {
    			const char *p = ch;
    			while( *p != '' ) pc(* p ++);
    			return * this;
    		}
    		~Output() { flush(); } 
    	#undef MX
    } Fout;
    
    #define cin Fin
    #define cout Fout
    #define endl '
    '
    
    using LL = long long;
    
    inline int log2(unsigned int x);
    inline int popcount(unsigned x);
    inline int popcount(unsigned long long x);
    
    template <int mod>
    class Int {
    	private :
    		inline int Mod(int x) { return x + ((x >> 31) & mod); } 
    		inline int power(int x, int k) {
    			int res = 1;
    			while(k) {
    				if(k & 1) res = 1LL * x * res % mod;
    				x = 1LL * x * x % mod; k >>= 1;
    			}
    			return res;
    		}
    	public :
    		int v;
    		Int(int _v = 0) : v(_v) {}
    		operator int() { return v; }
    		
    		inline Int operator =(Int x) { return Int(v = x.v); }
    		inline Int operator =(int x) { return Int(v = x); }
    		inline Int operator *(Int x) { return Int(1LL * v * x.v % mod); }
    		inline Int operator *(int x) { return Int(1LL * v * x % mod); }
    		inline Int operator +(Int x) { return Int( Mod(v + x.v - mod) ); }
    		inline Int operator +(int x) { return Int( Mod(v + x - mod) ); }
    		inline Int operator -(Int x) { return Int( Mod(v - x.v) ); }
    		inline Int operator -(int x) { return Int( Mod(v - x) ); }
    		inline Int operator ~() { return Int(power(v, mod - 2)); }
    		inline Int operator +=(Int x) { return Int(v = Mod(v + x.v - mod)); }
    		inline Int operator +=(int x) { return Int(v = Mod(v + x - mod)); }
    		inline Int operator -=(Int x) { return Int(v = Mod(v - x.v)); }
    		inline Int operator -=(int x) { return Int(v = Mod(v - x)); }
    		inline Int operator *=(Int x) { return Int(v = 1LL * v * x.v % mod); }
    		inline Int operator *=(int x) { return Int(v = 1LL * v * x % mod); }
    		inline Int operator /=(Int x) { return Int(v = v / x.v); }
    		inline Int operator /=(int x) { return Int(v = v / x); }
    		inline Int operator ^(int k) { return Int(power(v, k)); }
    } ;
    
    using mint = Int<(int) (1e9 + 7)>;
    
    const int N = 2e5 + 10;
    const mint inv2 = ~ mint(2);
    
    int n;
    int a[N], b[N];
    mint cnt;
    
    struct Node {
    	Node *ls, *rs;
    	mint cj, tg;
    	int l, r;
    	Node() {}
    	Node(int _l, int _r) : l(_l), r(_r), cj(0), tg(1), ls(NULL), rs(NULL) {}
    	void upd() {
    		cj = ls -> cj + rs -> cj;
    	}
    	void downcj(mint v) { cj *= v; tg *= v; }
    	void pushdown() {
    		if(tg != 1) {
    			ls -> downcj(tg);
    			rs -> downcj(tg);
    			tg = 1;
    		}
    	}
    	void modify(int pos, mint v) {
    		if(l == r) { cj = v; return ; }
    		pushdown();
    		int mid = (l + r) >> 1;
    		if(pos <= mid) ls -> modify(pos, v);
    		else rs -> modify(pos, v);
    		upd();
    	}
    	mint qry(int L, int R) {
    		if(L <= l && r <= R) return cj;
    		pushdown();
    		int mid = (l + r) >> 1;
    		mint res = 0;
    		if(L <= mid) res += ls -> qry(L, R);
    		if(R > mid) res += rs -> qry(L, R);
    		return res;
    	}
    	void mul(mint v) { downcj(v); return ; }
    } ;
    
    Node *root;
    
    Node *build(int l, int r) {
    	Node * x = new Node(l, r);
    	if(l == r) return x;
    	int mid = (l + r) >> 1;
    	x -> ls = build(l, mid);
    	x -> rs = build(mid + 1, r);
    	return x;
    }
    
    using pii = pair<int, int>;
    
    int c[N];
    #define lowbit(x) (x & -x)
    void upd(int x, int y) {
    	for(; x <= n; x += lowbit(x)) c[x] += y;
    }
    int qry(int x) {
    	int ans = 0;
    	for(; x; x -= lowbit(x)) ans += c[x];
    	return ans;
    }
    
    int main() {
    	cin >> n;
    	for(int i = 1; i <= n; i ++) cin >> a[i];
    	for(int i = 1; i <= n; i ++) b[a[i]] ++;
    	for(int i = n; i >= 1; i --) b[i] += b[i + 1];
    	cnt = 1;
    	for(int i = 1; i <= n; i ++) cnt = cnt * (b[i] - (n - i));
    	root = build(1, n);
    	mint ans = 0;
    	static vector<int> lim[N];
    	for(int i = 1; i <= n; i ++) lim[a[i]].push_back(i);
    	for(int i = 1; i <= n; i ++) {
    		mint value = b[i] - (n - i) - 1;
    		value = value * (~ (value + 1));
    		root -> mul(value);
    		for(int j : lim[i]) ans += root -> qry(1, j);
    		for(int j : lim[i]) root -> modify(j, cnt * inv2);
    		mint t = lim[i].size();
    		ans += t * (t - 1) * inv2 * cnt * inv2;
    	}
    	//cout << ans << endl;
    	for(int i = n; i >= 1; i --) {
    		ans += cnt * qry(a[i] - 1);
    		upd(a[i], 1); 
    	}
    //	cerr << ans << endl;
    	root = build(1, n);
    	for(int i = 1; i <= n; i ++) {
    		mint value = b[i] - (n - i) - 1;
    		value = value * (~ (value + 1));
    		root -> mul(value);
    		for(int j : lim[i]) ans -= root -> qry(j, n);
    		for(int j : lim[i]) root -> modify(j, cnt * inv2);
    		//mint t = lim[i].size();
    		//ans -= t * (t - 1) * inv2 * cnt * inv2;
    	}
    	cout << ans << endl;
    	return 0;
    }
    
    inline int log2(unsigned int x) { return __builtin_ffs(x); }
    inline int popcount(unsigned int x) { return __builtin_popcount(x); }
    inline int popcount(unsigned long long x) { return __builtin_popcountl(x); }
    
  • 相关阅读:
    腾讯TBS加载网页无法自适应记录
    filter过滤器实现验证跳转_返回验证结果
    Oracle不连续的值,如何实现查找上一条、下一条
    springmvc.xml 中 <url-pattern></url-pattern>节点详解
    spring拦截器-过滤器的区别
    (转)spring中的拦截器(HandlerInterceptor+MethodInterceptor)
    @Value("${xxxx}")注解的配置及使用
    mybatis BindingException: Invalid bound statement (not found)
    spring声明式事务管理方式( 基于tx和aop名字空间的xml配置+@Transactional注解)
    Spring事务管理详解_基本原理_事务管理方式
  • 原文地址:https://www.cnblogs.com/clover4/p/15304569.html
Copyright © 2011-2022 走看看