考场上自己想出来的做法：以 1 号点为根，按子树内最大点权和做长链剖分，再贪心选取权值和最大的 k 条链。
Code:
// Greedy over a "heaviest-sum" chain decomposition of a tree rooted at node 1.
//
// Input  (stdin): n k, then n node values, then n-1 edges "a b" (1-based).
// Output (stdout): the sum of the k largest chain sums (all chains if there
//                  are fewer than k).
//
// Every node extends the chain of the child whose downward sum is largest;
// every other child starts a new chain.  Presumably this answers "maximum
// total value covered by k root-to-leaf paths" with non-negative node values
// -- confirm against the original problem statement.
#include <cstdio>
#include <algorithm>
#include <queue>
#include <vector>
#include <string>

typedef long long ll;

// Capacity of the per-node arrays; valid node ids are 1 .. maxn-1.
const int maxn = 200009;

// Redirects stdin/stdout to "<a>.in" / "<a>.out" (file-based judging).
void setIO(std::string a) {
    freopen((a + ".in").c_str(), "r", stdin);
    freopen((a + ".out").c_str(), "w", stdout);
}

// Closes the standard streams; counterpart of setIO().
void shutIO() {
    fclose(stdin);
    fclose(stdout);
}

// Priority-queue entry: a chain identified by its top node `u`, keyed by the
// chain's value sum `dist`.
struct Node {
    ll dist;
    int u;
    Node(ll d = 0, int node = 0) : dist(d), u(node) {}
    // Ordered by sum so that std::priority_queue pops the largest sum first.
    bool operator<(const Node& a) const { return dist < a.dist; }
};

std::priority_queue<Node> Q;   // one entry per chain, largest sum on top

int p[maxn];                   // union-find parent; one set per chain
int tag[maxn];                 // tag[r] != 0: the chain with representative r was taken
int n, k;                      // node count, number of chains to take
int cnt;                       // number of directed edges stored so far
int val[maxn];                 // node values
int head[maxn];                // head[u]: first edge index of u's adjacency list (0 = none)
int to[maxn << 1];             // to[e]: target node of edge e
int nex[maxn << 1];            // nex[e]: next edge in the same adjacency list (0 = end)
ll sumv[maxn];                 // sumv[u]: val[u] plus the best downward chain sum below u

// Scratch space for the iterative traversal in dfs().
static int bfsOrder[maxn];
static int parentOf[maxn];

// Union-find lookup with path compression.
int find(int x) {
    return p[x] == x ? x : p[x] = find(p[x]);
}

// Joins the sets of a and b; the representative of b's set becomes the root.
void merge(int a, int b) {
    int x = find(a), y = find(b);
    if (x == y) return;
    p[x] = y;
}

// Prepends directed edge u -> v to u's adjacency list.
void addedge(int u, int v) {
    nex[++cnt] = head[u];
    head[u] = cnt;
    to[cnt] = v;
}

// Marks the chain containing `a` as taken.
// Returns true the first time a chain is marked, false if it already was.
// dfs() queues each chain once (via its top node), so this acts as a guard
// against counting a chain twice.
bool check(int a) {
    int x = find(a);
    if (tag[x]) return false;
    tag[x] = 1;
    return true;
}

// Processes the subtree of `u` (whose parent is `fa`, 0 for the root):
//   * fills sumv[] for every node in the subtree,
//   * merges each node with the child that has the largest positive sumv
//     (the first such child in adjacency order wins ties),
//   * pushes every other child, as the top of its own chain, into Q.
// The caller is responsible for pushing `u` itself.
//
// Implemented iteratively: the original recursion went one call frame per tree
// level, which risks overflowing the call stack on a path-shaped tree with
// ~2*10^5 nodes.  The graph must be a tree (no cycles, no duplicate edges).
void dfs(int u, int fa) {
    // Pass 1: breadth-first order, so every parent precedes its children.
    int len = 0;
    parentOf[u] = fa;
    bfsOrder[len++] = u;
    for (int i = 0; i < len; ++i) {
        int x = bfsOrder[i];
        for (int e = head[x]; e; e = nex[e]) {
            int y = to[e];
            if (y == parentOf[x]) continue;
            parentOf[y] = x;
            bfsOrder[len++] = y;
        }
    }

    // Pass 2: reverse order, so every child is finished before its parent.
    for (int i = len - 1; i >= 0; --i) {
        int x = bfsOrder[i];
        sumv[x] = (ll)val[x];

        int heavy = 0;   // child whose chain x extends (0 = none)
        ll best = 0;     // strictly-greater test: zero-sum children are never chosen
        for (int e = head[x]; e; e = nex[e]) {
            int y = to[e];
            if (y == parentOf[x]) continue;
            if (sumv[y] > best) {
                best = sumv[y];
                heavy = y;
            }
        }
        // No child with a positive sum: x ends its chain here.  Such children
        // are neither merged nor queued; their sums are <= 0.
        if (!heavy) continue;

        sumv[x] += best;
        merge(x, heavy);
        for (int e = head[x]; e; e = nex[e]) {
            int y = to[e];
            if (y == parentOf[x] || y == heavy) continue;
            Q.push(Node(sumv[y], y));   // y is the top of a separate chain
        }
    }
}

int main() {
    // setIO("game");
    if (scanf("%d%d", &n, &k) != 2 || n < 1) {
        // Missing header or empty tree: nothing can be collected.
        printf("0");
        return 0;
    }
    if (n >= maxn) {
        fprintf(stderr, "n=%d exceeds the supported maximum of %d\n", n, maxn - 1);
        return 1;
    }

    for (int i = 1; i <= n; ++i) {
        if (scanf("%d", &val[i]) != 1) break;   // remaining values stay 0
    }
    for (int i = 1; i <= n; ++i) p[i] = i;

    for (int i = 1; i < n; ++i) {
        int a, b;
        if (scanf("%d%d", &a, &b) != 2) break;  // truncated input: use what was read
        if (a < 1 || a > n || b < 1 || b > n) {
            fprintf(stderr, "edge endpoint out of range: %d %d\n", a, b);
            return 1;
        }
        addedge(a, b);
        addedge(b, a);
    }

    dfs(1, 0);
    Q.push(Node(sumv[1], 1));   // the chain that starts at the root

    // Take the k largest chains.
    int cur = 0;
    ll fin = 0;
    while (!Q.empty() && cur < k) {
        Node a = Q.top();
        Q.pop();
        if (check(a.u)) {
            fin += a.dist;
            cur += 1;
        }
    }
    printf("%lld", fin);
    // shutIO();
    return 0;
}