@
题意
给你一颗(n(1e5))个点有边权有点权的树,(Min(u,v))表示(u,v)路径最小点权,(gcd(u,v))表示(u,v)路径点权的最大公因数,(dis(u,v))表示(u,v)路径大小。
输出(max(dis(u,v)*gcd(u,v)*Min(u,v)))
解析
- 法一:
- 外层枚举路径的gcd,并把两端点是gcd倍数的边存下,按照两端点较小的权值排序。
- 每次加一条边,并查集动态维护直径即可。路径最小点权就是当前加的边的最小点权。
- 因为平均每条边的因子个数不超过sqrt(10000)
- 所以总的枚举边的个数不超过n*sqrt(10000)
- 复杂度:(O(nlog(n) + n*sqrt(10000)))
- 法二:暴搜+剪枝
- 敢剪就敢过
- dia = 树的直径
- 剪枝:当diagcdmin <= ans时,return
- 先从直径的一个端点开始暴搜+剪枝预处理一遍。(不然会tle,可能先以直径为端点搜会搜到大答案的概率更大,剪枝效果更明显)
- 然后枚举起点开始暴搜,记录当前路径的gcd和min
- 法三:点分治
- 按重心分治,从重心开始搜索,记录每条路径的{second ancestor,dis, Gcd, Min}
- 按Min从小到大排序。这个排序很精髓。
- 因为所有路径都含有重心这个点,那么gcd的可能性就只有不到 sqrt (10000)种。
- 后面在枚举边的时候,动态对每种gcd记录最远的次远的路径,两条路径的次祖先不能相同。
- 先枚举边,再枚举gcd,看看这个gcd记录的两条路径能否和当前枚举的边组合出更优解。
- 复杂度肯定小于:(O(nlog(n) + n*sqrt(10000)))
AC_code
/*
* Codechef March Cook-Off 2018. Maximum Tree Path
* 法一:
* 外层枚举路径的gcd,并把两端点是gcd倍数的边存下,按照两端点较小的权值排序。
* 每次加一条边,并查集动态维护直径即可。路径最小点权就是当前加的边的最小点权。
* 因为平均每条边的因子个数不超过sqrt(10000)
* 所以总的枚举边的个数不超过n*sqrt(10000)
* 复杂度:O(nlog(n) + n*sqrt(10000))
* 法二:暴搜+剪枝
* 敢剪就敢过
* dia = 树的直径
* 剪枝:当dia*gcd*min <= ans时,return
* 先从直径的一个端点开始暴搜+剪枝预处理一遍。(不然会tle,可能先以直径为端点搜会搜到大答案的概率更大,剪枝效果更明显)
* 然后枚举起点开始暴搜,记录当前路径的gcd和min
* 法三:点分治
* */
#pragma comment(linker, "/STACK:102400000,102400000")
#include<bits/stdc++.h>
#define fi first
#define se second
#define endl '
'
#define o2(x) (x)*(x)
#define BASE_MAX 30
#define mk make_pair
#define eb emplace_back
#define all(x) (x).begin(), (x).end()
#define clr(a, b) memset((a),(b),sizeof((a)))
#define iis std::ios::sync_with_stdio(false); cin.tie(0)
#define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
using namespace std;
#pragma optimize("-O3")
typedef long long LL;
typedef unsigned long long uLL;
typedef pair<int, int> pii;
inline LL read() {
LL x = 0;int f = 0;
char ch = getchar();
while (ch < '0' || ch > '9') f |= (ch == '-'), ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
return x = f ? -x : x;
}
inline void write(LL x, bool f) {
if (x == 0) {putchar('0'); if(f)putchar('
');else putchar(' ');return;}
if (x < 0) {putchar('-');x = -x;}
static char s[23];
int l = 0;
while (x != 0)s[l++] = x % 10 + 48, x /= 10;
while (l)putchar(s[--l]);
if(f)putchar('
');else putchar(' ');
}
int lowbit(int x) { return x & (-x); }
template<class T>T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
template<typename T, typename ...R>T big(const T &f, const R &...r) { return big(f, big(r...)); }
template<class T>T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
template<typename T, typename ...R>T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
void debug_out() { cerr << '
'; }
template<typename T, typename ...R>void debug_out(const T &f, const R &...r) {cerr << f << " ";debug_out(r...);}
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__);
const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;
const int HMOD[] = {1000000009, 1004535809};
const LL BASE[] = {1572872831, 1971536491};
const int mod = 998244353;
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int MXN = 5e5 + 7;
const int MXE = 1e6 + 7;
int n, m;
vector<pii > mp[MXN];
vector<int> has[MXN];
int ar[MXN];
pair<pii, int> cw[MXN];
int stk[MXN], top;
LL ans;
namespace LCA {
LL dis[MXN];
int up[MXN][20], lens[MXN];
int cnt, dfn[MXN], en[MXN], LOG[MXN];
void dfs(int u, int ba) {
lens[u] = lens[ba] + 1;
dfn[++cnt] = u;
en[u] = cnt;
for(auto V: mp[u]) {
int v = V.fi;
if(v == ba) continue;
dis[v] = dis[u] + V.se;
dfs(v, u);
dfn[++ cnt] = u;
}
}
inline int cmp(int u, int v) {
return lens[u] < lens[v] ? u: v;
}
void init() {
cnt = 0;
dis[0] = lens[0] = 0;
dfs(1, 0);
LOG[1] = 0;
for(int i = 2; i <= cnt; ++i) LOG[i] = LOG[i-1] + ((1<<(LOG[i-1]+1))==i);
for(int i = 1; i <= cnt; ++i) up[i][0] = dfn[i];
for(int j = 1; (1<<j) <= cnt; ++j)
for(int i = 1; i + (1<<j) -1 <= cnt; ++i)
up[i][j] = cmp(up[i][j-1], up[i+(1<<(j-1))][j-1]);
}
inline int lca(int x, int y) {
int l = en[x], r = en[y];
if(l > r) swap(l, r);
int k = LOG[r - l + 1];
return cmp(up[l][k], up[r-(1<<k)+1][k]);
}
inline LL query(int i, int j) {
return dis[i] + dis[j] - 2 * dis[lca(i, j)];
}
}
int fa[MXN];
pii data[MXN];
pii merge(pii A, pii B) {
int a[4];
a[0] = A.fi, a[1] = A.se, a[2] = B.fi, a[3] = B.se;
pii tmp = A;
LL res = 0;
for(int i = 0; i < 4; ++i) {
for(int j = i + 1; j < 4; ++j) {
LL ret = LCA::query(a[i], a[j]);
if(ret > res) {
res = ret;
tmp = mk(a[i], a[j]);
}
}
}
return tmp;
}
bool cmp(const int&a, const int&b) {
return cw[a].se > cw[b].se;
}
int Fi(int x) {
return fa[x] == x? x: fa[x] = Fi(fa[x]);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin);
//freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
#endif
for(int i = 0; i <= 10000; ++i) stk[i] = i;
int tim = read();
while(tim --) {
n = read();
for(int i = 1; i <= n; ++i) mp[i].clear();
for(int i = 1; i <= n; ++i) ar[i] = read();
top = 0;
for(int i = 1, a, b, c; i < n; ++i) {
a = read(), b = read(), c = read();
cw[i] = mk(mk(a, b), sml(ar[a], ar[b]));
mp[a].eb(mk(b, c));
mp[b].eb(mk(a, c));
// stk[++top] = c;
c = __gcd(ar[a], ar[b]);
for(int j = 1; j * j <= c; ++j) if(c % j == 0) {
has[j].eb(i);
if(c/j != j) has[c/j].eb(i);
}
}
ans = 0;
LCA::init();
// sort(stk + 1, stk + 1 + top);
// top = unique(stk + 1, stk + 1 + top) - stk;
for(int i = 1; i <= 10000; ++i) {
if((int)has[stk[i]].size() <= 0)continue;
for(auto V: has[stk[i]]) {
fa[cw[V].fi.fi] = cw[V].fi.fi, fa[cw[V].fi.se] = cw[V].fi.se;
data[cw[V].fi.fi] = mk(cw[V].fi.fi, cw[V].fi.fi), data[cw[V].fi.se] = mk(cw[V].fi.se, cw[V].fi.se);
}
sort(all(has[stk[i]]), cmp);
for(auto V: has[stk[i]]) {
int pa = Fi(cw[V].fi.fi), pb = Fi(cw[V].fi.se);
fa[pa] = pb;
data[pb] = merge(data[pa], data[pb]);
ans = big(ans, LCA::query(data[pb].fi, data[pb].se) * (LL)cw[V].se * stk[i]);
}
}
printf("%lld
", ans);
for(int i = 1; i <= 10000; ++i) has[i].clear();
}
#ifndef ONLINE_JUDGE
cout << "time cost:" << clock() << "ms" << endl;
#endif
return 0;
}
const int MXN = 5e5 + 7;
const int MXE = 1e6 + 7;
int n, m;
int ar[MXN], sz[MXN], son[MXN], inde, fid[MXN], lid[MXN], rid[MXN], dep[MXN];
vector<pii > mp[MXN];
int stk[MXN], top, aim;
LL dia, dis[MXN], ans;
void dfs_sz(int u, int ba) {
sz[u] = 1;
son[u] = 0;
fid[u] = ++ inde;
rid[inde] = u;
for(auto V: mp[u]) {
if(V.fi == ba) continue;
dep[V.fi] = dep[u] + 1;
dfs_sz(V.fi, u);
sz[u] += sz[V.fi];
if(sz[V.fi] > sz[son[u]]) son[u] = V.fi;
}
lid[u] = inde;
}
void dfs(int u, int ba) {
for(auto V: mp[u]) {
if(V.fi == ba) continue;
dis[V.fi] = dis[u] + V.se;
dfs(V.fi, u);
}
}
void chk(int u, int ba, LL len, int Gcd, int Min) {
if(dia * Gcd * Min <= ans) return;
ans = big(ans, len * Gcd * Min);
for(auto V: mp[u]) {
if(V.fi == ba) continue;
chk(V.fi, u, len + V.se, __gcd(Gcd, ar[V.fi]), sml(Min, ar[V.fi]));
}
}
int main() {
int tim = read();
while(tim --) {
n = read();
for(int i = 1; i <= n; ++i) mp[i].clear();
for(int i = 1; i <= n; ++i) ar[i] = read();
inde = top = ans = 0;
for(int i = 1, a, b, c; i < n; ++i) {
a = read(), b = read(), c = read();
mp[a].eb(mk(b, c));
mp[b].eb(mk(a, c));
stk[++top] = c;
}
sort(stk + 1, stk + 1 + top);
top = unique(stk + 1, stk + 1 + top) - stk;
int S = 1, T = 1;
dis[1] = 0;
dfs(1, 0);
for(int i = 1; i <= n; ++i) if(dis[i] > dis[S]) S = i;
dis[S] = 0;
dfs(S, 0);
for(int i = 1; i <= n; ++i) if(dis[i] > dis[T]) T = i;
dia = dis[T];
chk(S, 0, 0, ar[S], ar[S]);
for(int i = 1; i <= n; ++i) chk(i, 0, 0, ar[i], ar[i]);
printf("%lld
", ans);
}
return 0;
}
点分治做法我还没补,大致思路和上面一样,先贴一下大佬的代码。
#include<bits/stdc++.h>
#define ll long long
#define ull unsigned ll
#define uint ungigned
#define db double
#define pii pair<int,int>
#define pll pair<ll,ll>
#define pli pair<ll,int>
#define vi vector<int>
#define vpi vector<pii >
#define IT iterator
#define PB push_back
#define MK make_pair
#define LB lower_bound
#define UB upper_bound
#define y1 wzpakking
#define fi first
#define se second
#define BG begin
#define ED end
#define For(i,j,k) for (int i=(int)(j);i<=(int)(k);i++)
#define Rep(i,j,k) for (int i=(int)(j);i>=(int)(k);i--)
#define UPD(x,y) (((x)+=(y))>=mo?(x)-=mo:233)
#define CLR(a,v) memset(a,v,sizeof(a))
#define CPY(a,b) memcpy(a,b,sizeof(a))
#define sqr(x) (1ll*x*x)
#define LS3 k*2,l,mid
#define RS3 k*2+1,mid+1,r
#define LS5 k*2,l,mid,x,y
#define RS5 k*2+1,mid+1,r,x,y
#define GET pushdown(k);int mid=(l+r)/2
#define INF (1<<29)
using namespace std;
int gcd(int x,int y){
return y?gcd(y,x%y):x;
}
const int N=100005;
vector<int> divi[N];
int head[N],vis[N],tot;
int sz[N],a[N],mx[N],rt;
int n;
ll ans;
struct edge{
int to,next,v;
}e[N*2];
struct node{
int fr;
ll mn,G,dis;
bool operator <(const node &a)const{
return mn<a.mn;
}
}g[N],g2[N];
vector<node> v;
void add(int x,int y,int v){
e[++tot]=(edge){y,head[x],v};
head[x]=tot;
}
void Dfs(int x,int fa,ll dis,int G,int mn,int fr){
mn=min(mn,a[x]); G=gcd(G,a[x]);
v.push_back((node){fr,mn,G,dis});
for (int i=head[x];i;i=e[i].next)
if (!vis[e[i].to]&&e[i].to!=fa)
Dfs(e[i].to,x,dis+e[i].v,G,mn,fr);
}
void solve(int x){
v.clear();
for (int i=head[x];i;i=e[i].next)
if (!vis[e[i].to])
Dfs(e[i].to,0,e[i].v,a[x],a[x],e[i].to);
sort(v.begin(),v.end());
for (auto i:divi[a[x]])
g[i]=g2[i]=(node){0,0,0,0};
Rep(i,v.size()-1,0){
int fr=v[i].fr;
ll mn=v[i].mn;
ll G=v[i].G;
ll dis=v[i].dis;
ans=max(ans,dis*G*mn);
for (auto di:divi[a[x]])
if (fr!=g[di].fr&&g[di].dis)
ans=max(ans,(dis+g[di].dis)*min(mn,g[di].mn)*gcd(G,di));
else if (g2[di].dis)
ans=max(ans,(dis+g2[di].dis)*min(mn,g2[di].mn)*gcd(G,di));
if (dis>g[G].dis){
if (g[G].fr!=fr) g2[G]=g[G];
g[G]=v[i];
}
else if (dis>g2[G].dis&&fr!=g[G].fr)
g2[G]=v[i];
}
}
void dfs(int x,int fa,int Sz){
mx[x]=0; sz[x]=1;
for (int i=head[x];i;i=e[i].next)
if (e[i].to!=fa&&!vis[e[i].to]){
dfs(e[i].to,x,Sz);
mx[x]=max(mx[x],sz[e[i].to]);
sz[x]+=sz[e[i].to];
}
mx[x]=max(mx[x],Sz-sz[x]);
if (mx[x]<mx[rt]) rt=x;
}
void divide(int x,int Sz){
rt=0; dfs(x,0,Sz);
vis[x=rt]=1; solve(x);
for (int i=head[x];i;i=e[i].next)
if (!vis[e[i].to]){
int nsz;
if (sz[e[i].to]>sz[x])
nsz=Sz-sz[x];
else nsz=sz[e[i].to];
divide(e[i].to,nsz);
}
}
void solve(){
ans=tot=0; mx[0]=1e9;
scanf("%d",&n);
For(i,0,n+1) sz[i]=head[i]=vis[i]=0;
For(i,1,n) scanf("%d",&a[i]);
For(i,1,n-1){
int x,y,v;
scanf("%d%d%d",&x,&y,&v);
add(x,y,v); add(y,x,v);
}
divide(1,n);
printf("%lld
",ans);
}
void init(){
For(i,1,10000) For(j,1,10000/i)
divi[i*j].PB(i);
}
int main(){
init();
int T;
scanf("%d",&T);
while (T--) solve();
}