问题描述
二维平面上,给定N(大约200000)个点,这些点的x和y的取值范围都是[1,3000]之间的整数,给定M(大约200000)个查询,每个查询输入一个点P(px,py)。对于每个查询,求N个点到点P的距离之和。
输入
第一行两个整数N和M
接下来N行表示N个点的x和y坐标
接下来M行表示M个查询的x和y坐标
3 3
5 5
5 10
10 5
1 1
5 5
10 10
输出
M个正整数,每个查询输出一个正整数
思路
二维前缀和
定义二维数组xsum[3001][3001],xsum[x,y]表示区域[0,x],[0,y]上x坐标之和。这样一来,任意区域[fx,tx],[fy,ty]的x坐标之和就可以表示为xsum[fy,ty]-xsum[fx,ty]-xsum[tx,fy]+xsum[fx,tx]。
对于每个查询的点P,只需要处理它的左上方、左下方、右上方、右下方四个区域中的距离之和(方向使用二维空间直角坐标系)。
例如,点P左上方的点数为K,点P左上方的x坐标之和、y坐标之和分别为xs,ys,则点P左上方的点到P的距离为:K*px-xs+ys-K*py
代码
#include<stdio.h>
#include<iostream>
using namespace std;
const int maxn = 200007;
const int ma = 3003;
int a[ma][ma];
int xsum[ma][ma], ysum[ma][ma];
int c[ma][ma];
//前开后闭区间
int getxsum(int fx, int fy, int tx, int ty) {//前开后闭区间
return xsum[tx][ty] - xsum[fx][ty] - xsum[tx][fy] + xsum[fx][fy];
}
int getysum(int fx, int fy, int tx, int ty) {
return ysum[tx][ty] - ysum[fx][ty] - ysum[tx][fy] + ysum[fx][fy];
}
int getcnt(int fx, int fy, int tx, int ty) {
return c[tx][ty] - c[tx][fy] - c[fx][ty] + c[fx][fy];
}
int main() {
freopen("in.txt", "r", stdin);
int N, M;
cin >> N >> M;
memset(a, 0, sizeof(a));
memset(xsum, 0, sizeof(xsum));
memset(ysum, 0, sizeof(ysum));
memset(c, 0, sizeof(c));
for (int i = 0; i < N; i++) {
int x, y;
cin >> x >> y;
a[x][y]++;
}
for (int x = 1; x < ma; x++) {
for (int y = 1; y < ma; y++) {
xsum[x][y] = a[x][y] * x + xsum[x][y - 1]+xsum[x-1][y]-xsum[x-1][y-1];
ysum[x][y] = a[x][y] * y + ysum[x][y - 1]+ysum[x-1][y]-ysum[x-1][y-1];
c[x][y] = a[x][y]-c[x-1][y-1]+c[x-1][y]+c[x][y-1];
}
}
for (int i = 0; i < M; i++) {
int x, y;
cin >> x >> y;
int s = 0;
int cnt = getcnt(0, 0, x , y );
s += cnt * x + cnt * y - getxsum(0, 0, x , y) - getysum(0, 0, x, y );
cnt = getcnt(x , 0, ma-1, y );
s += getxsum(x, 0, ma-1, y) - cnt*x + cnt * y - getysum(x , 0, ma-1, y );
cnt = getcnt(0, y , x, ma-1);
s += getysum(0, y, x, ma-1) - y * cnt + cnt * x - getxsum(0, y, x, ma-1);
cnt = getcnt(x , y , ma-1, ma-1);
s += getxsum(x , y , ma-1, ma-1) + getysum(x , y , ma-1, ma-1) - cnt * x - cnt * y;
cout << s << endl;
}
return 0;
}
总结
多维前缀和的计算方式可以认为是容斥原理。对于三维前缀和,加减号可以直接通过二进制表示的奇偶性来表示。
下面来段代码测试一下这个思路
import java.util.Random;
public class Main {
class Point {
int x, y, z;
Point(int x, int y, int z) {
this.x = x;
this.y = y;
this.z = z;
}
@Override
public String toString() {
return String.format("(%d,%d,%d)", x, y, z);
}
}
final int N = 100;//100*100*100的三维空间内
final int POINT_COUNT = 4000;//点的个数
final int QUESTION_COUNT = 4000;//问题的个数
Random r = new Random(0);
//随机一个点,点的各个维度取值范围要在1到N-1之间
Point randomPoint() {
return new Point(r.nextInt(N - 2) + 1, r.nextInt(N - 2) + 1, r.nextInt(N - 2) + 1);
}
//生成问题
Point[] generateProblem() {
Point[] a = new Point[POINT_COUNT];
for (int i = 0; i < a.length; i++) {
a[i] = randomPoint();
}
return a;
}
//绝对正确的方法
class Stupid {
Point[] a;
Stupid(Point[] a) {
this.a = a;
}
int solve(Point fp, Point tp) {
int s = 0;
for (Point i : a) {
if (i.x >= fp.x && i.y >= fp.y && i.z >= fp.z && i.x <= tp.x && i.y <= tp.y && i.z <= tp.z) {
s++;
}
}
return s;
}
}
//快速方法
class Fast {
int[][][] a, c;
Fast(Point[] p) {
//a[i,j,k]表示i,j,k处的点的个数
a = new int[N][N][N];
//c[i,j,k]表示000到ijk处的点的总数
c = new int[N][N][N];
for (Point i : p) {
a[i.x][i.y][i.z]++;
}
for (int i = 1; i < N; i++) {
for (int j = 1; j < N; j++) {
for (int k = 1; k < N; k++) {
c[i][j][k] = c[i - 1][j][k] + c[i][j - 1][k] + c[i][j][k - 1] - c[i - 1][j - 1][k] - c[i - 1][j][k - 1] - c[i][j - 1][k - 1] + c[i - 1][j - 1][k - 1] - a[i][j][k];
}
}
}
}
int solve(Point fp, Point tp) {
int s = 0;
for (int i = 0; i < 8; i++) {
//i各个bit
int one = i & 1, two = (i >> 1) & 1, three = (i >> 2) & 1;
int x = tp.x, y = tp.y, z = tp.z;
//符号位
int sgn = (one ^ two ^ three) == 0 ? -1 : 1;
if (one != 0) x = fp.x - 1;
if (two != 0) y = fp.y - 1;
if (three != 0) z = fp.z - 1;
s += c[x][y][z] * sgn;
}
return s;
}
}
Main() {
Point[] p = generateProblem();
Stupid stupid = new Stupid(p);
Fast fast = new Fast(p);
for (int i = 0; i < QUESTION_COUNT; i++) {
//生成一对点作为查询区间,起始点的各个坐标必须小于终结点的各个坐标
Point fp = randomPoint(), tp = randomPoint();
if (fp.x > tp.x) {
int temp = fp.x;
fp.x = tp.x;
tp.x = temp;
}
if (fp.y > tp.y) {
int temp = fp.y;
fp.y = tp.y;
tp.y = temp;
}
if (fp.z > tp.z) {
int temp = fp.z;
fp.z = tp.z;
tp.z = temp;
}
int realAns = stupid.solve(fp, tp);
int mine = fast.solve(fp, tp);
if (realAns != mine) {
throw new RuntimeException("error on from=" + fp + ",to=" + tp + " " + realAns + " " + mine);
}
}
}
public static void main(String[] args) {
new Main();
}
}