好久之前在 cmd's blog 看到过,这次做题遇上了,就学了一下,其实挺 easy 的。
众所周知其实是我不会证 (n) 个点 ((x_i,y_i)) 可以唯一确定一个次数为 (n-1) 的多项式,拉格朗日插值给出了一种构造:
[f(z)=sum_{i=1}^{n} dfrac{y_iprod_{j
ot=i}(z-x_j)}{prod_{j
ot=i}(x_i-x_j)}
]
首先提出常数部分:
[a_i=dfrac{y_i}{prod_{j
ot=i}(x_i-x_j)}
]
可以 (O(n^2)) 搞出每一个 (a_i)。
然后求一个多项式 (g(z)=prod_{i=1}^{n} (z-x_i))。
可以发现
[f(z)=sum_{i=1}^{n}a_idfrac{g(z)}{z-x_i}
]
考虑如何快速搞出后面那个 (dfrac{g(z)}{z-x_i})。
设 (h(z)=dfrac{g(z)}{z-c})。
可以得到 ((z-c)h(z)=g(z))。两边提取系数得到
[[z^{i-1}]h-c[z^i]h=[z^i]g\
[z^i]h=dfrac{[z^i]g-[z^{i-1}]h}{-c}
]
递推即可。
给出 模板题 通过代码:
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mkp(x,y) make_pair(x,y)
#define pb(x) push_back(x)
#define sz(v) (int)v.size()
typedef long long LL;
typedef double db;
template<class T>bool ckmax(T&x,T y){return x<y?x=y,1:0;}
template<class T>bool ckmin(T&x,T y){return x>y?x=y,1:0;}
#define rep(i,x,y) for(int i=x,i##end=y;i<=i##end;++i)
#define per(i,x,y) for(int i=x,i##end=y;i>=i##end;--i)
inline int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f?x:-x;
}
#define mod 998244353
inline int qpow(int n, int k) {
int res = 1;
for(; k; k >>= 1, n = 1ll * n * n % mod)
if(k & 1) res = 1ll * n * res % mod;
return res;
}
vector <int> lagrange(const vector <int> &x, const vector <int> &y) {
assert(x.size() == y.size());
int n = x.size();
vector <int> a(n, 0), b(n + 1, 0), c(n + 1, 0), f(n, 0);
for(int i = 0; i < n; ++i) {
int A = 1;
for(int j = 0; j < n; ++j) if(i != j)
A = 1ll * A * (x[i] - x[j] + mod) % mod;
a[i] = 1ll * qpow(A, mod - 2) * y[i] % mod;
}
b[0] = 1;
for(int i = 0; i < n; ++i) {
for(int j = i + 1; j >= 1; --j)
b[j] = (1ll * b[j] * (mod - x[i]) + b[j - 1]) % mod;
b[0] = 1ll * b[0] * (mod - x[i]) % mod;
}
for(int i = 0; i < n; ++i) {
int iv = qpow(mod - x[i], mod - 2);
if(!iv) {
for(int j = 0; j < n; ++j) c[j] = b[j + 1];
} else {
c[0] = 1ll * b[0] * iv % mod;
for(int j = 1; j <= n; ++j)
c[j] = 1ll * (b[j] + mod - c[j - 1]) * iv % mod;
}
for(int j = 0; j < n; ++j)
f[j] = (f[j] + 1ll * a[i] * c[j] % mod) % mod;
}
return f;
}
inline int calc(const vector <int> &f, int x) {
int res = 0;
for(int i = f.size() - 1; i >= 0; --i) res = (1ll * res * x + f[i]) % mod;
return res;
}
signed main() {
int n = read(), k = read();
vector <int> x(n), y(n);
for(int i = 0; i < n; ++i) x[i] = read(), y[i] = read();
vector <int> f = lagrange(x, y);
cout << calc(f, k) << '
';
}