点分治入门题
距离上一次写点分治好像快有两年了呢…那时候好像快省选吧,在济南培训,离别的前夕写的。
后来一直没练过。今天拿出来写个入门题,假装我没学过了。
代码写得一团乱麻。~%?…,# *'☆&℃$︿★?
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 200005;
vector <pair<int,int> > g[N];
int n,k,vis[N],dis[N],dep[N],siz[N],mxsiz[N],u[N],ans=1e+12;
vector <int> vec_1;
int mp[N*10];
vector <pair<int,int> > vec_2;
vector <int> mppos;
vector <int> vipos;
int cnt = 0;
void reset_vis() {
for(int i=0;i<vipos.size();i++) {
vis[vipos[i]]=0;
}
vipos.clear();
}
int dfs1(int p){
vis[p]=1; vipos.push_back(p);
siz[p]=1;
mxsiz[p]=0;
vec_1.push_back(p);
for(int i=0;i<g[p].size();i++) {
int q=g[p][i].first, w=g[p][i].second;
if(vis[q]==0 && u[q]==0) {
dfs1(q);
siz[p]+=siz[q];
mxsiz[p]=max(mxsiz[p],siz[q]);
}
}
//mxsiz[p]=max(mxsiz[p],n-siz[p]);
}
int FindRoot(int p) {
reset_vis();
vec_1.clear();
dfs1(p);
int mv=1e+9, mx=0;
for(int i=0;i<vec_1.size();i++) {
int q=vec_1[i];
mxsiz[q]=max(mxsiz[q],(long long)vec_1.size()-siz[q]);
if(mxsiz[q] < mv) {
mv = mxsiz[q];
mx = q;
}
}
vec_1.clear();
return mx;
}
void dfs2(int p) {
vis[p]=1; vipos.push_back(p);
for(int i=0;i<g[p].size();i++) {
int q=g[p][i].first, w=g[p][i].second;
if(vis[q]==0 && u[q]==0) {
dis[q]=dis[p]+w;
dep[q]=dep[p]+1;
dfs2(q);
}
}
vec_2.push_back(make_pair(dis[p],dep[p]));
if(k-dis[p]>=0 && mp[k-dis[p]]<1e+6) ans=min(ans,dep[p]+mp[k-dis[p]]);
}
void reset_map() {
for(int i=0;i<mppos.size();i++) {
mp[mppos[i]]=1e+16;
}
mppos.clear();
}
void solve(int p) {
if(u[p]) return;
int r = FindRoot(p);
//cout<<r<<endl;
u[r]=1;
reset_vis();
reset_map();
mp[0]=0;
for(int i=0;i<g[r].size();i++) {
int q=g[r][i].first, w=g[r][i].second;
if(vis[q]==0 && u[q]==0) {
dis[q]=w;
dep[q]=1;
dfs2(q);
for(int j=0;j<vec_2.size();j++) {
int x=vec_2[j].first, y=vec_2[j].second;
if(x>1e+6) continue;
mp[x]=min(mp[x],y);
mppos.push_back(x);
}
vec_2.clear();
}
}
for(int i=0;i<g[r].size();i++) {
int q=g[r][i].first;
if(!u[q]) solve(q);
}
}
signed main() {
scanf("%lld%lld",&n,&k);
memset(mp,0x3f,sizeof mp);
if(k==0) {
cout<<0<<endl;
return 0;
}
for(int i=1;i<n;i++) {
int u,v,w;
scanf("%lld%lld%lld",&u,&v,&w);
++u; ++v;
g[u].push_back(make_pair(v,w));
g[v].push_back(make_pair(u,w));
}
solve(1);
if(ans >= 1e+12) cout<<-1<<endl;
else cout<<ans<<endl;
return 0;
}