4543: [POI2014]Hotel加强版
分析:
f[u][i]表示子树u内,距离u为i的点的个数,g[u][i]表示在子树u内,已经选了两个深度一样的点,还需要在距离u为i的一个点作为第三个点。
然后就可以利用这两个数组统计答案了。
ans+=g[u][j]*f[v][j-1]+f[u][j]*g[v][j+1];
如果直接合并f和g,复杂度是$O(n^2)$的,如果可以启发式合并,复杂度是$O(nlogn)$的,如果是长链剖分,复杂度是$O(n)$的。
长链剖分就是按照子树内最长的链就行剖分,根节点保留深度最大的那个点的信息,其他的点暴力合并。
代码:
1 #include<cstdio> 2 #include<algorithm> 3 #include<cstring> 4 #include<iostream> 5 #include<cmath> 6 #include<cctype> 7 #include<set> 8 #include<queue> 9 #include<vector> 10 #include<map> 11 using namespace std; 12 typedef long long LL; 13 14 inline int read() { 15 int x=0,f=1;char ch=getchar();for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-1; 16 for(;isdigit(ch);ch=getchar())x=x*10+ch-'0';return x*f; 17 } 18 19 const int N = 100005; 20 struct Edge{ int to, nxt; } e[N << 1]; 21 int head[N], len[N], fa[N], son[N], En; 22 LL ans, *f[N], *g[N], ft[N], gt[N << 1], *fp = ft, *gp = gt; 23 24 inline void add_edge(int u,int v) { 25 ++En; e[En].to = v, e[En].nxt = head[u]; head[u] = En; 26 ++En; e[En].to = u, e[En].nxt = head[v]; head[v] = En; 27 } 28 void dfs(int u) { 29 len[u] = 1; 30 for (int i = head[u]; i; i = e[i].nxt) { 31 int v = e[i].to; 32 if (v == fa[u]) continue; 33 fa[v] = u; 34 dfs(v); 35 if (len[v] + 1 > len[u]) len[u] = len[v] + 1, son[u] = v; 36 } 37 } 38 void solve(int u) { 39 f[u][0] = 1; 40 if (!son[u]) return ; 41 int v = son[u]; 42 g[v] = g[u] - 1; f[v] = f[u] + 1; 43 solve(v); 44 ans += g[u][0]; 45 for (int i = head[u]; i; i = e[i].nxt) { 46 v = e[i].to; 47 if (v == fa[u] || v == son[u]) continue; 48 g[v] = gp + len[v] + 1; gp += (len[v] * 2); 49 f[v] = fp + 1; fp += len[v]; 50 solve(v); 51 for (int j = 0; j <= len[v]; ++j) { 52 if (j) ans += g[u][j] * f[v][j - 1]; 53 ans += f[u][j] * g[v][j + 1]; 54 if (j) g[u][j] += f[u][j] * f[v][j - 1]; 55 g[u][j] += g[v][j + 1]; 56 if (j) f[u][j] += f[v][j - 1]; 57 } 58 } 59 } 60 int main() { 61 int n = read(); 62 for (int i = 1; i < n; ++i) { 63 int u = read(), v = read(); 64 add_edge(u, v); 65 } 66 dfs(1); 67 g[1] = gp + len[1] + 1; gp += (len[1] * 2); 68 f[1] = fp + 1; fp += len[1]; 69 solve(1); 70 cout << ans; 71 return 0; 72 }