题目描述
喜欢数学的wlswls最近被萎住了。
现在他一共有1...n1...n这么多数字,取数字ii会得到f[i]f[i]的收益。数字之间有些边,对于所有的i(i != 1)i(i!=1),若ii为奇数,则ii与3i+13i+1之间有边,否则ii与i/2i/2之间有边。如果一条边的两个顶点xyxy都被取了,那么会失去d[min(x, y)]d[min(x,y)]的价值。请问wlswls怎么取,才能使得收益最大?
输入描述
第一行一个整数nn。
接下来一行nn个整数表示ff。
接下来一行nn个整数表示dd。
1 leq n leq 1001≤n≤100
1 leq f[i], d[i] leq 10001≤f[i],d[i]≤1000
输出描述
输出一个整数表示答案。
样例输入 1
2 10 10 1 2
样例输出 1
19
思路:
根据题目给的建边条件,建边后会形成一个森林,然后把森林转化为一个0为根节点的树,随后进行树形dp。
定义状态:
dp[u][0/1] 0为 第u个节点的子树中不取第u个节点的最多利益值,
1为第u个节点的子树中取第u个节点的最多利益值,
常规的树形dp套路,
dp[u][0]+=max(dp[v][0],dp[v][1]);
dp[u][1]+=max(dp[v][0],dp[v][1]-d[min(u,v)]);// 一个边上两个节点都取的话,要减去对应的值。
最后max(dp[0][0],dp[0][1])就是我们的答案值。
细节见代码:
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <cmath> #include <queue> #include <stack> #include <map> #include <set> #include <vector> #include <iomanip> #define ALL(x) (x).begin(), (x).end() #define rt return #define dll(x) scanf("%I64d",&x) #define xll(x) printf("%I64d ",x) #define sz(a) int(a.size()) #define all(a) a.begin(), a.end() #define rep(i,x,n) for(int i=x;i<n;i++) #define repd(i,x,n) for(int i=x;i<=n;i++) #define pii pair<int,int> #define pll pair<long long ,long long> #define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0) #define MS0(X) memset((X), 0, sizeof((X))) #define MSC0(X) memset((X), ' ', sizeof((X))) #define pb push_back #define mp make_pair #define fi first #define se second #define eps 1e-6 #define gg(x) getInt(&x) #define db(x) cout<<"== [ "<<x<<" ] =="<<endl; using namespace std; typedef long long ll; ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;} ll lcm(ll a, ll b) {return a / gcd(a, b) * b;} ll powmod(ll a, ll b, ll MOD) {ll ans = 1; while (b) {if (b % 2)ans = ans * a % MOD; a = a * a % MOD; b /= 2;} return ans;} inline void getInt(int* p); const int maxn = 1010; const int inf = 0x3f3f3f3f; /*** TEMPLATE CODE * * STARTS HERE ***/ int pre[maxn]; int f[maxn]; int d[maxn]; int dp[maxn][2]; int n; int findpar(int x) { return pre[x] == 0 ? x : pre[x] = findpar(pre[x]); } void mer(int x, int y) { x = findpar(x); y = findpar(y); if (x != y) { pre[x] = y; } } std::vector<int> v[maxn]; int w; void dfs(int u, int pre) { // cout<<u<<" "<<pre<<endl; dp[u][0] = 0; dp[u][1] = f[u]; for (auto x : v[u]) { if (x != pre) { dfs(x, u); dp[u][0] += max(dp[x][0], dp[x][1]); dp[u][1] += max(dp[x][0], dp[x][1] - d[min(u, x)]); } } } int main() { //freopen("D:\common_text\code_stream\in.txt","r",stdin); //freopen("D:\common_text\code_stream\out.txt","w",stdout); gbtb; cin >> n; repd(i, 1, n) { cin >> f[i]; } repd(i, 1, n) { cin >> d[i]; } repd(i, 2, n) { if (i & 1) { if (3 * i + 1 <= n) { v[i].push_back(3 * i + 1); v[3 * i + 1].push_back(i); mer(i, 3 * i + 1); } } else { v[i].push_back(i / 2); v[i / 2].push_back(i); mer(i, i / 2); } } repd(i, 1, n) { if (pre[i] == 0) { v[0].push_back(i); v[i].push_back(0); } } dfs(0, 0); cout << max(dp[0][0], dp[0][1]) << endl; return 0; } inline void getInt(int* p) { char ch; do { ch = getchar(); } while (ch == ' ' || ch == ' '); if (ch == '-') { *p = -(getchar() - '0'); while ((ch = getchar()) >= '0' && ch <= '9') { *p = *p * 10 - ch + '0'; } } else { *p = ch - '0'; while ((ch = getchar()) >= '0' && ch <= '9') { *p = *p * 10 + ch - '0'; } } }