作用
实际上是求一些关于整除的和式,比如
[sum_{i=0}^nlfloorfrac{ai+b}{c}
floor ,
sum_{i=0}^nlfloorfrac{ai+b}{c}
floor^2 ,
sum_{i=0}^nilfloorfrac{ai+b}{c}
floor
]
(这里的三个例子来自luoguP5170)
虽然叫这个名字,但其实跟 gcd 没什么关系,只是复杂度和复杂度分析相同。所以复杂度是 (O(logmax(a,c))) 。
第一个式子实际上是求直线下方在第一象限内的整点数
算法思想
没啥思想,就是推式子,运用一些基本的数学技巧。
推上面那三个式子需要的 trick:
[lfloorfrac{ai+b}{c}
floor=lfloorfrac{(amod c)i+(b mod c)}{c}
floor + ilfloorfrac{a}{c}
floor+lfloorfrac{b}{c}
floor \
X^2=2sum_{d=0}^nd-X
]
第一个式子非常好证明,只需要把 a 和 b拆成 kc+r 的形式即可。
第二个式子是通过添加一个求和来消除二次方,同理三次方也可以用类似方法消掉。之后可以通过交换求和顺序来化简。
这种东西分为两类来考虑: a 或 b 大于等于 c ;a 和 b 都小于 c。
(a ge c or b ge c)
运用第一个式子很容易从 (a,b,c,n) 递归到 (a%c,b%c,c,n) 。
(a,b < c)
此时发现 (lfloorfrac{ai+b}{c} floorle n) ,可以用这个缩小范围。令 (m=lfloorfrac{an+b}{c} floor) :
[f(a,b,c,n) = sum_{i=0}^nlfloorfrac{ai+b}{c}
floor \
= sum_{i=0}^nsum_{d=0}^{m-1}[lfloorfrac{ai+b}{c}
floor ge d+1] \
= sum_{d=0}^{m-1}sum_{i=0}^n[i>lfloorfrac{cd+c-b-1}{a}
floor] \
= sum_{d=0}^{m-1}n - lfloorfrac{cd+c-b-1}{a}
floor\
= n*m - f(c,c-b-1,a,m-1)
]
重点其实在第一步,这个谓词很关键(实际上是 (X=sum_{d=1}^X1))。后面的就很显然了。再推一个。
[g(a,b,c,n) = sum_{i=0}^nlfloorfrac{ai+b}{c}
floor^2 \
= sum_{i=0}^n2sum_{d=1}^{lfloorfrac{ai+b}{c}
floor}d - lfloorfrac{ai+b}{c}
floor\
= 2sum_{d=0}^{m-1}d+1sum_{i=0}^n[lfloorfrac{ai+b}{c}
floor ge d+1] - f(a,b,c,d)
]
这里用到了上述的第二个式子。其他的都类似,不再赘述。
应用
其中
[f=sum_{i=0}^nlfloorfrac{ai+b}{c}
floor ,
g=sum_{i=0}^nlfloorfrac{ai+b}{c}
floor^2 ,h=
sum_{i=0}^nilfloorfrac{ai+b}{c}
floor
]
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef long double LD;
typedef pair<int,int> pii;
typedef pair<LL,int> pli;
const int SZ = 1e6 + 10;
const LL INF = 1e15 + 10;
const int mod = 998244353;
const LD eps = 1e-8;
LL read() {
LL n = 0;
char a = getchar();
bool flag = 0;
while(a > '9' || a < '0') { if(a == '-') flag = 1; a = getchar(); }
while(a <= '9' && a >= '0') { n = n * 10 + a - '0',a = getchar(); }
if(flag) n = -n;
return n;
}
struct node {
LL f,g,h;
};
struct LO {
const LL inv2 = 499122177;
const LL inv6 = 166374059;
LL F1(LL n) {
n %= mod;
return n * (n+1) % mod * inv2 % mod;
}
LL F2(LL n) {
n %= mod;
return n * (n+1) % mod * (2*n+1) % mod * inv6 % mod;
}
node solve(LL a,LL b,LL c,LL n) {
LL f,g,h;
if(a == 0) {
n %= mod;
f = (b/c%mod) * (n+1) % mod;
g = (b/c%mod) * (b/c%mod)%mod * (n+1) % mod;
h = (b/c%mod) * F1(n) % mod;
}
else if(a>=c || b>=c) {
node t = solve(a%c,b%c,c,n);
f = (t.f + F1(n)*(a/c%mod)%mod + ((n+1)%mod)*(b/c%mod)%mod) % mod;
g = (F2(n)*(a/c%mod)%mod*(a/c%mod)%mod + ((n+1)%mod)*(b/c%mod)%mod*(b/c%mod)%mod + t.g
+ 2*F1(n)*(a/c%mod)%mod*(b/c%mod)%mod + 2*(a/c%mod)*t.h%mod + 2*(b/c%mod)*t.f%mod) % mod;
h = (t.h + F2(n)*(a/c%mod)%mod + F1(n)*(b/c%mod)%mod) % mod;
}
else {
LL m = (a*n+b)/c;
node t = solve(c,c-b-1,a,m-1);
m %= mod;
f = (n*m%mod - t.f) % mod;
g = (2*n*F1(m)%mod - 2*(t.h+t.f) - f) % mod;
h = (F1(n)*m%mod - (t.g+t.f)*inv2%mod) % mod;
}
f += mod; f %= mod;
g += mod; g %= mod;
h += mod; h %= mod;
return (node){f,g,h};
}
}lo;
int main() {
int T = read();
while(T --) {
LL n = read(),a = read(),b = read(),c = read();
node ans = lo.solve(a,b,c,n);
printf("%lld %lld %lld
",ans.f,ans.g,ans.h);
}
}