上图来源于 CF 官方题解（标程解析）。
1 #include<iostream>
2 #include<algorithm>
3 #include<cstdio>
4 #include<vector>
5 #define ll long long
6 #define fir first
7 #define sec second
8 using namespace std;
using pii = std::pair<int, int>;   // (x, y) grid position of a snake

const int N = 2010;                // array bound for snake indices and grid coordinates
const int mod = 1000000007;        // every count is reported modulo this prime

long long n, m, r;                 // number of snakes, size of the chosen subset (see comb(n, m)), square radius
long long b[N];                    // b[i]: weight of the i-th snake as read from input

long long sum[N][N];               // after main's prefix pass: number of snakes in [1..i] x [1..j]
16
// Returns the larger of a and b (b when they are equal).
// NOTE(review): because the file has `using namespace std` and <algorithm>, a call
// such as max(1, someInt) with two int arguments resolves to std::max<int>, not here.
long long max(long long a, long long b)
{
    return (a > b) ? a : b;
}
22
// Returns the smaller of a and b (a when they are equal).
long long min(long long a, long long b)
{
    return (b < a) ? b : a;
}
28
29 ll quick_pow(ll base, ll k)
30 {
31 ll res = 1;
32 while(k)
33 {
34 if(k & 1){
35 res *= base % mod;
36 res %= mod;
37 }
38 base *= base;
39 base %= mod;
40 k >>= 1;
41 }
42 return res % mod;
43 }
44
45 ll inv[N << 1];
46 ll f[N << 1];
47
48 void pre()//预处理
49 {
50 f[0] = inv[0] = 1;
51 for(ll i = 1 ; i < N * 2 ; i++){
52 f[i] = 1ll * f[i - 1] * i % mod;//阶乘
53 inv[i] = quick_pow(i, mod - 2) % mod * inv[i - 1] % mod;//求阶乘的逆元
54 }
55 }
56
57 ll getsum2(int x1, int y1, int x2, int y2)
58 {
59 x1 = max(1, x1);
60 y1 = max(1, y1);
61 x2 = min(1000, x2);
62 y2 = min(1000, y2);
63 if(x1 > x2 || y2 < y1) return 0;
64 ll tmp = sum[x2][y2] - sum[x1 - 1][y2] - sum[x2][y1 - 1] + sum[x1 - 1][y1 - 1];
65 return tmp;
66 }
67
68 ll getsum1(int x1, int y1, int r)
69 {
70 return getsum2(x1 - r, y1 - r, x1 + r, y1 + r);
71 }
72
73 ll comb(ll n, ll m)
74 {
75 if(m > n || m < 0) return 0;
76 return f[n] % mod * inv[n - m] % mod * inv[m] % mod;
77 }
78
79 int main(){
80 pre();
81 scanf("%lld%lld%lld",&n,&m,&r);
82 vector<pii> snakes;
83
84 for(int i = 0 ; i < n ; i++){
85 int x, y;
86 scanf("%d%d%lld",&x,&y,&b[i]);
87 sum[x][y]++;
88 snakes.push_back(make_pair(x, y));
89 }
90
91 for(int i = 1 ; i < N ; i++){
92 for(int j = 1 ; j < N ; j++){
93 sum[i][j] += sum[i - 1][j] + sum[i][j - 1] - sum[i - 1][j - 1];//容斥
94 }
95 }
96
97 ll res = 0;
98 int sz = (int)snakes.size();
99 for(int i = 0 ; i < sz ; i++){
100 for(int j = i ; j < sz ; j++){//
101 int x1 = snakes[i].fir;
102 int y1 = snakes[i].sec;
103 int x2 = snakes[j].fir;
104 int y2 = snakes[j].sec;
105
106 int X1 = max(x1 - r, x2 - r);
107 int Y1 = max(y1 - r, y2 - r);
108 int X2 = min(x1 + r, x2 + r);
109 int Y2 = min(y1 + r, y2 + r);
110
111 ll w = getsum2(X1, Y1, X2, Y2);//i死j死
112 ll u = getsum1(x1, y1, r) - w;//i死j不死
113 ll v = getsum1(x2, y2, r) - w;//i不死j死
114
115 ll tmp = 0;
116 tmp += comb(n, m) - comb(n - w, m);//
117 if(tmp < 0) tmp += mod;
118
119 tmp += comb(n - w, m) - comb(n - u - w, m) - comb(n - v - w, m) + comb(n - u - v - w, m);//
120 tmp %= mod;
121 if(tmp < 0){
122 tmp += mod;
123 }
124
125 if(i == j){
126 res += tmp * b[i] % mod * b[j] % mod;
127 }else{
128 res += (2ll * tmp % mod * b[i] % mod * b[j] % mod) % mod;
129 }
130 res %= mod;
131 }
132 }
133 printf("%lld
",res % mod);
134
135 return 0;
136 }
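最后的选取部分还没理解透彻，整理如下。对一对蛇 (i, j)，记：
- w：同时落在以 i、以 j 为中心、半径 r 的两个正方形内的蛇数；
- u：只在 i 的正方形内、不在交集内的蛇数；
- v：只在 j 的正方形内、不在交集内的蛇数。

要统计"选出的 m 个目标使 i 和 j 都死"的方案数，可分成两种互斥情况：
1. 至少有一个目标落在 w 内（一个目标就同时杀死 i 和 j）：$\binom{n}{m} - \binom{n-w}{m}$；
2. 没有目标落在 w 内，但 u、v 中各至少有一个目标。做法是在 n-w 条蛇中选 m 条，再对"u 中一个都没选""v 中一个都没选"做容斥：$\binom{n-w}{m} - \binom{n-w-u}{m} - \binom{n-w-v}{m} + \binom{n-w-u-v}{m}$。

两式相加即代码中的 tmp。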