题意:
给一棵N个点的带边权的树。对于1~N的每一个全排列,排列中每一对相邻数字a和b贡献树上顶点a到顶点b的路径长度(路径上的边权之和);求所有N!个排列的贡献总和(对1e9+7取模)。
思路:
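对每一条边,设删去它后两侧分别有x和y个点(x+y=n),边权为w。任意两个点a、b在排列中相邻的排列共有$2(n-1)!$个:把a、b捆成一块后有$(n-1)!$种排法,块内又有2种顺序。路径经过这条边的点对恰好有$xy$对,所以这条边对答案的贡献为$2xyw(n-1)!$。于是答案为 $\displaystyle \sum_{e} 2xyw(n-1)! = \sum_{e} 2x(n-x)w(n-1)!$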
其中每条边的x取该边下方(远离根的一侧)子树的节点个数,可以通过一次DFS求出所有子树的大小。
比赛时没想到如何同时存下每条边的权值和对应的子树节点个数,只好另开一个边表反复调试,最后勉强卡着时间和空间限制通过。赛后看别人的代码学到了新方法:把边权和邻接点一起存进邻接表,DFS时顺便把每个点连向父节点的那条边的权值记在该点上。
比赛代码:
运行时间 842MS,内存 14012K
// Contest version: store the edges in a separate list, compute subtree sizes
// with one DFS from vertex 1, then attribute each edge's weight to whichever
// endpoint is the child (the endpoint with the smaller subtree).
//
// Answer = 2 * (n-1)! * sum over edges of s * (n - s) * w   (mod 1e9+7),
// where s is the number of vertices on the child side of the edge.
#include <cstdio>
#include <vector>
using namespace std;

typedef long long ll;

const ll mod = 1000000007LL;

// One undirected tree edge as read from the input.
struct Edge {
    int u;
    int v;
    int w;
};

// Reads the next non-negative decimal integer from stdin, skipping any
// leading non-digit characters (a '-' sign is skipped, exactly like the
// original reader). Returns false when input ends before a digit is found;
// the original version looped forever at EOF because EOF (-1) < '0'.
static bool read_int(int& out) {
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        ch = getchar();
    }
    if (ch == EOF) {
        return false;
    }
    int num = 0;
    while (ch >= '0' && ch <= '9') {
        num = num * 10 + (ch - '0');
        ch = getchar();
    }
    out = num;
    return true;
}

// Returns sz where sz[x] is the number of vertices in x's subtree when the
// tree on vertices 1..n is rooted at vertex 1.
// Iterative on purpose: a path-shaped tree with ~1e5 vertices would make a
// recursive DFS overflow the default call stack.
static vector<int> subtree_sizes(int n, const vector<vector<int> >& adj) {
    vector<int> parent(n + 1, 0);
    vector<int> order;  // vertices in DFS preorder
    order.reserve(n);
    vector<int> stk(1, 1);
    parent[1] = -1;
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for (size_t i = 0; i < adj[x].size(); ++i) {
            int y = adj[x][i];
            if (y != parent[x]) {
                parent[y] = x;
                stk.push_back(y);
            }
        }
    }
    vector<int> sz(n + 1, 1);
    // Walking the preorder backwards finishes every child before its parent;
    // order[0] is the root and has no parent to add into.
    for (size_t k = order.size(); k-- > 1;) {
        int x = order[k];
        sz[parent[x]] += sz[x];
    }
    return sz;
}

int main() {
    int n;
    while (read_int(n)) {
        if (n <= 1) {
            // A single vertex has no edges, so every contribution is 0.
            printf("0\n");
            continue;
        }

        vector<vector<int> > adj(n + 1);
        vector<Edge> edges(n - 1);
        bool complete = true;
        for (int i = 0; i < n - 1; ++i) {
            Edge& e = edges[i];
            if (!read_int(e.u) || !read_int(e.v) || !read_int(e.w)) {
                complete = false;
                break;
            }
            adj[e.u].push_back(e.v);
            adj[e.v].push_back(e.u);
        }
        if (!complete) {
            break;  // truncated input: no full test case left to answer
        }

        vector<int> sz = subtree_sizes(n, adj);

        // (n-1)! mod p, computed per test case so any n works (the original
        // precomputed table stopped at 100000).
        ll fact = 1;
        for (int i = 2; i <= n - 1; ++i) {
            fact = fact * i % mod;
        }

        ll sum = 0;
        for (int i = 0; i < n - 1; ++i) {
            const Edge& e = edges[i];
            // A parent's subtree strictly contains its child's, so the child
            // side of the edge is the endpoint with the smaller size. The root
            // has sz == n, so vertex 1 needs no special case.
            ll s = (sz[e.u] < sz[e.v]) ? sz[e.u] : sz[e.v];
            ll pairs = s * (n - s) % mod;  // s*(n-s) <= n^2/4 fits in 64 bits
            sum = (sum + pairs * (e.w % mod)) % mod;
        }
        ll ans = 2 * sum % mod * fact % mod;
        // ans < 1e9+7 < 2^31, so %d is exact and avoids the MSVC-only %I64d.
        printf("%d\n", (int)ans);
    }
    return 0;
}
赛后代码:
运行时间 499MS,内存 18972K
// Post-contest version: keep each edge's weight inside the adjacency list, so
// a single DFS from vertex 1 can record, for every non-root vertex x,
//   w[x]  = weight of the edge from x to its parent, and
//   sz[x] = number of vertices in x's subtree.
// Answer = 2 * (n-1)! * sum over x != 1 of sz[x] * (n - sz[x]) * w[x]  (mod 1e9+7).
#include <iostream>
#include <vector>
using namespace std;

typedef long long ll;

const ll mod = 1000000007LL;

// Half of an undirected edge, stored in the adjacency list of one endpoint.
struct Edge {
    int u;  // the other endpoint
    ll w;   // edge weight
};

int main() {
    // iostream reads/writes long long portably (the original used the
    // MSVC-only %I64d) and is fast once detached from C stdio.
    ios::sync_with_stdio(false);
    cin.tie(0);

    int n;
    while (cin >> n) {
        if (n <= 1) {
            // A single vertex has no edges, so every contribution is 0.
            cout << 0 << '\n';
            continue;
        }

        vector<vector<Edge> > adj(n + 1);
        for (int i = 0; i < n - 1; ++i) {
            int x, y;
            ll l;
            cin >> x >> y >> l;
            Edge to_y;
            to_y.u = y;
            to_y.w = l;
            Edge to_x;
            to_x.u = x;
            to_x.w = l;
            adj[x].push_back(to_y);
            adj[y].push_back(to_x);
        }
        if (!cin) {
            break;  // truncated input: no full test case left to answer
        }

        // Iterative DFS from vertex 1 (a recursive one can overflow the stack
        // on a path-shaped tree of ~1e5 vertices).
        vector<int> parent(n + 1, 0);
        vector<ll> w(n + 1, 0);  // w[x] = weight of edge (x, parent[x])
        vector<int> order;       // vertices in DFS preorder
        order.reserve(n);
        vector<int> stk(1, 1);
        parent[1] = -1;
        while (!stk.empty()) {
            int x = stk.back();
            stk.pop_back();
            order.push_back(x);
            for (size_t i = 0; i < adj[x].size(); ++i) {
                const Edge& e = adj[x][i];
                if (e.u == parent[x]) {
                    continue;
                }
                parent[e.u] = x;
                w[e.u] = e.w;
                stk.push_back(e.u);
            }
        }

        // Reverse preorder finishes every child before its parent;
        // order[0] is the root and has no parent to add into.
        vector<ll> sz(n + 1, 1);
        for (size_t k = order.size(); k-- > 1;) {
            int x = order[k];
            sz[parent[x]] += sz[x];
        }

        // (n-1)! mod p, computed per test case: the original fixed-size table
        // ran out of bounds for n > 100099.
        ll fact = 1;
        for (int i = 2; i <= n - 1; ++i) {
            fact = fact * i % mod;
        }

        ll sum = 0;
        for (size_t k = 1; k < order.size(); ++k) {
            int x = order[k];
            // Reduce before multiplying: the original formed
            // tmp * 2*f*(n-f), which is within 2x of overflowing long long at n=1e5.
            ll pairs = sz[x] * (n - sz[x]) % mod;
            // The weight is read as 64-bit, so it may exceed mod (or be negative).
            ll wm = (w[x] % mod + mod) % mod;
            sum = (sum + pairs * wm) % mod;
        }
        cout << 2 * sum % mod * fact % mod << '\n';
    }
    return 0;
}