考虑一棵基环树：只有挂在环上同一个点下面的点对之间恰好只有一条路径，其余点对之间都有两条路径。我们的目标是使只有一条路径的点对数目尽可能少。
dp[u] 表示从 u 的子树中向上延伸到 u 的一条最优链。但是当 u 作为两条链的交汇点的时候，需要从它的儿子中选两个
合起来得到最优值，这一步需要用斜率去优化，但是好像被我用一个很蠢的贪心水过去了……明天再补一个斜率
优化的版本吧。
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
typedef pair<LL, LL> PLL;

const int N = 5e5 + 7;
const LL INF = 0x3f3f3f3f3f3f3f3f;

int n;             // number of vertices, read in main()
vector<int> G[N];  // adjacency lists of the input tree, vertices are 1-based
LL dp[N];          // dp[u]: cheapest downward chain starting at u (see dfs)
LL sz[N];          // sz[u]: number of vertices in the subtree of u (rooted at 1)
LL mn, ans;        // mn: cheapest path found so far; ans: n * (n - 1)

// Number of unordered pairs among k vertices. This is the cost of one path
// vertex whose attached component (the vertex itself included) has k vertices.
LL c2(LL k) { return k * (k - 1) / 2; }

// All children of one vertex that share the same subtree size, reduced to the
// two cheapest chains among them.
struct SizeClass {
    LL size;         // common subtree size of the children in this class
    LL best;         // smallest dp value in the class
    LL second;       // second smallest dp value (valid only if hasSecond)
    bool hasSecond;  // true when the class contains at least two children
};

// Computes sz[u] and dp[u] for the subtree of u and lowers the global `mn`
// with every path whose topmost vertex is u.
//
// dp[u] is the minimum, over chains going from u down into its subtree, of the
// sum of c2(component size) over the chain vertices, where only vertices of
// u's subtree are counted.
//
// Two kinds of paths have u as their topmost vertex:
//   * one chain into a child v: u keeps everything outside v's subtree,
//     so the cost is dp[v] + c2(n - sz[v]);
//   * two chains into different children a and b:
//     the cost is dp[a] + dp[b] + c2(n - sz[a] - sz[b]).
//
// The second case used to be handled by trying only the two children with the
// largest subtrees, which is a heuristic and can miss the optimum. It is now
// searched exactly: children are grouped by subtree size, only the two
// cheapest chains per size are kept, and all pairs of size classes are tried.
// The distinct sizes are distinct positive integers summing to less than
// sz[u], so there are only O(sqrt(sz[u])) classes per vertex.
void dfs(int u, int fa) {
    sz[u] = 1;
    dp[u] = INF;
    if ((int)G[u].size() == 1 && fa) {
        // Non-root leaf: the chain is the vertex itself, c2(1) == 0.
        dp[u] = 0;
        return;
    }

    for (size_t i = 0; i < G[u].size(); ++i) {
        int v = G[u][i];
        if (v == fa) continue;
        dfs(v, u);
        sz[u] += sz[v];
    }

    vector<PLL> kids;  // (subtree size, dp) of every child
    for (size_t i = 0; i < G[u].size(); ++i) {
        int v = G[u][i];
        if (v == fa) continue;
        dp[u] = min(dp[u], dp[v] + c2(sz[u] - sz[v]));
        mn = min(mn, dp[v] + c2(n - sz[v]));
        kids.push_back(make_pair(sz[v], dp[v]));
    }

    // After sorting by (size, dp) the first two entries of every size are the
    // two cheapest chains of that size.
    sort(kids.begin(), kids.end());
    vector<SizeClass> classes;
    for (size_t i = 0; i < kids.size(); ++i) {
        if (classes.empty() || classes.back().size != kids[i].first) {
            SizeClass fresh = {kids[i].first, kids[i].second, 0, false};
            classes.push_back(fresh);
        } else if (!classes.back().hasSecond) {
            classes.back().second = kids[i].second;
            classes.back().hasSecond = true;
        }
    }

    for (size_t i = 0; i < classes.size(); ++i) {
        const SizeClass &a = classes[i];
        if (a.hasSecond) {
            // Both chains go into children of the same size.
            mn = min(mn, a.best + a.second + c2(n - 2 * a.size));
        }
        for (size_t j = i + 1; j < classes.size(); ++j) {
            const SizeClass &b = classes[j];
            mn = min(mn, a.best + b.best + c2(n - a.size - b.size));
        }
    }
}

// Reads a tree with n vertices (n, then n - 1 edges) from stdin and prints
// n * (n - 1) minus the cheapest path cost found by dfs().
int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    ans = 1LL * n * (n - 1);
    mn = 1LL * n * (n - 1) / 2;  // cost when no better path exists
    dfs(1, 0);
    printf("%lld ", ans - mn);
    return 0;
}
用维护直线的板子（LineContainer）实现的斜率优化版本：
#include <bits/stdc++.h>
using namespace std;

typedef long long LL;

const int N = 5e5 + 7;

namespace LC {

/**
 * Description: Container where you can add lines of the form kx+m, and query
 * maximum values at points x.
 * Time: O(log N)
 */
struct Line {
    // k: slope, m: intercept, p: last x at which this line is the maximum.
    mutable LL k, m, p;
    bool operator<(const Line &other) const { return k < other.k; }
    bool operator<(LL x) const { return p < x; }
};

struct LineContainer : multiset<Line, less<>> {
    // (for doubles, use inf = 1/.0, div(a,b) = a/b)
    const LL inf = LLONG_MAX;

    // Floored division of a by b.
    LL div(LL a, LL b) {
        LL quotient = a / b;
        if ((a ^ b) < 0 && a % b) --quotient;
        return quotient;
    }

    // Recomputes x->p from its successor y and reports whether y is never
    // the maximum (x->p >= y->p).
    bool isect(iterator x, iterator y) {
        if (y == end()) {
            x->p = inf;
            return false;
        }
        if (x->k != y->k) {
            x->p = div(y->m - x->m, x->k - y->k);
        } else if (x->m > y->m) {
            x->p = inf;
        } else {
            x->p = -inf;
        }
        return x->p >= y->p;
    }

    // Inserts the line k*x + m and erases lines that are no longer maximal.
    void add(LL k, LL m) {
        auto next = insert({k, m, 0});
        auto cur = next++;
        auto prev = cur;
        while (isect(cur, next)) next = erase(next);
        if (prev != begin() && isect(--prev, cur)) {
            cur = erase(cur);
            isect(prev, cur);
        }
        for (;;) {
            cur = prev;
            if (cur == begin()) break;
            --prev;
            if (prev->p < cur->p) break;
            isect(prev, erase(cur));
        }
    }

    // Maximum of k*x + m over all stored lines; the container must not be empty.
    LL query(LL x) {
        assert(!empty());
        const Line &top = *lower_bound(x);
        return top.k * x + top.m;
    }
};

}  // namespace LC

int n;             // number of vertices, read in main()
vector<int> G[N];  // adjacency lists of the input tree
LL dp[N], sz[N];   // dp[u]: cheapest chain from u downwards; sz[u]: subtree size
LL mn, ans;        // mn: cheapest path found so far; ans: n * (n - 1)

// Number of unordered pairs among k items.
LL c2(LL k) { return k * (k - 1) / 2; }

/*
 Joining the chains of two children a and b of the same vertex costs
   dp[a] + dp[b] + c2(n - sz[a] - sz[b])
 which expands to
   dp[a] + dp[b] + c2(n) - n * sz[a] - n * sz[b]
     + c2(sz[a] + 1) + c2(sz[b] + 1) + sz[a] * sz[b].
 With own(c) = dp[c] - n * sz[c] + c2(sz[c] + 1) this is
   own(a) + own(b) + c2(n) + sz[a] * sz[b],
 so every earlier child b is stored as the line -sz[b] * x - own(b) and the
 maximum at x = sz[a] gives the best partner for a.
*/
void dfs(int u, int fa) {
    sz[u] = 1;
    for (int child : G[u]) {
        if (child != fa) {
            dfs(child, u);
            sz[u] += sz[child];
        }
    }

    dp[u] = c2(sz[u]);  // chain consisting of u alone
    const LL allPairs = c2(n);
    LC::LineContainer hull;
    for (int child : G[u]) {
        if (child == fa) continue;
        const LL s = sz[child];
        const LL own = dp[child] - n * s + c2(s + 1);

        // Extend the child's chain up to u, counting only u's subtree.
        dp[u] = min(dp[u], dp[child] + c2(sz[u] - s));
        // Path that ends at u: u keeps everything outside the child's subtree.
        mn = min(mn, dp[child] + c2(n - s));
        // Path through u that joins this child with the best earlier child.
        if (!hull.empty()) {
            mn = min(mn, own + allPairs - hull.query(s));
        }
        hull.add(-s, -own);
    }
}

// Reads n and the n - 1 edges from stdin, then prints n * (n - 1) - mn.
int main() {
    scanf("%d", &n);
    for (int edge = 1; edge < n; ++edge) {
        int a, b;
        scanf("%d%d", &a, &b);
        G[a].push_back(b);
        G[b].push_back(a);
    }
    ans = 1LL * n * (n - 1);
    mn = 1LL * n * (n - 1) / 2;
    dfs(1, 0);
    printf("%lld ", ans - mn);
    return 0;
}