Problem Description
给出一棵 nnn 个节点的有边权的无根树,每个点为白色或黑色,每个点都能翻转颜色,但是需要花费一定价格。
我们定义一棵树收益为: ∑x∈V1∑y∈V2val(x,y)\sum\limits_{x \in V_1} \sum\limits_{y \in V2} val(x,y)x∈V1∑y∈V2∑val(x,y) , V1V_1V1 表示白点集合, V2V_2V2 表示黑点集合, val(x,y)val(x,y)val(x,y) 表示 xxx 到 yyy 最短路径上的最大边权。
求最大收益。
Input
第一行输入一个整数 nnn ,表示节点个数。
第二行输入 nnn 个整数 aia_iai (0≤ai≤1)(0 \le a_i \le 1)(0≤ai≤1) ,表示第 iii 个节点的颜色, 000 表示白色, 111 表示黑色。
第三行输入 nnn 个整数 costicost_icosti (0≤costi≤109)(0 \le cost_i \le 10^9)(0≤costi≤109) ,表示第 iii 个节点翻转所需的费用。
接下来 n−1n-1n−1 行每行输入三个整数 ui,vi,wiu_i,v_i,w_iui,vi,wi (1≤ui,vi,wi≤n)(1 \le u_i,v_i,w_i \le n)(1≤ui,vi,wi≤n) ,表示 uiu_iui 和 viv_ivi 之间有一条边权为 wiw_iwi 的边。
Output
输出最大收益。
Solution
观察到 val(x,y)val(x,y)val(x,y) 表示 xxx 到 yyy 最短路径上的最大边权,我们很显然地就能联想到 KruskalKruskalKruskal 重构树的构造和性质,两个点的路径最大边权即是重构树上两点的 LCALCALCA ,且构造重构树时是将两个集合合并,合并时的对应边权是两个集合路径上的最大边权。
这样我们可以将按边权为关键值排序后,从小到大枚举每条边的对答案的贡献。
那么问题来了如何求这个贡献?
此时我们便考虑 dpdpdp ,设计 dpu,idp_{u,i}dpu,i 表示当前集合 uuu 中白点个数为 iii 的最大值。当两个集合合并时,我们先去枚举总的白点个数,然后再枚举其中一个集合的白点个数,最终我们可以得到转移方程为:dpu,i=max(dpu,i,dpu,l+dpu,r+l∗(m−r)+(n−l)∗r)dp_{u,i}=max(dp_{u,i},dp_{u,l}+dp_{u,r}+l*(m-r)+(n-l)*r)dpu,i=max(dpu,i,dpu,l+dpu,r+l∗(m−r)+(n−l)∗r)。
而对于翻转消费 costicost_icosti ,我们可以在初始化时候,将 dpi,a[i]=−costidp_{i,a[i]}=-cost_idpi,a[i]=−costi , dpi,a[i]⊕1=0dp_{i,a[i] \oplus 1}=0dpi,a[i]⊕1=0 即可,即表示成有 0,10,10,1 个白点的收益。
最后我们考虑复杂度,整体集合合并为 O(n)O(n)O(n),dpdpdp 转移的上界是 O(n2)O(n^2)O(n2) ,极端情况下复杂度会达到 O(n3)O(n^3)O(n3) ,这是我们无法接受的。
但是我们仔细考虑到 dpdpdp 转移的第二层是枚举其中一个集合的大小,此时我们考虑到启发式合并,每次枚举小的集合,这样最终的复杂度便是O(n2logn)O(n^2logn)O(n2logn) ,能够通过此题。
Code
#include <bits/stdc++.h>
#define endl '\n'
using namespace std;
typedef long long ll;
constexpr int N = 3010;
int p[N];
int leader(int x) {
while (x != p[x])x = p[x] = p[p[x]];
return x;
}
vector<ll>dp[N];
void merge(int u, int v, int w) {
u = leader(u), v = leader(v);
if (dp[u].size() > dp[v].size())swap(u, v);
int n = dp[u].size() - 1, m = dp[v].size() - 1;
vector<ll>temp(n + m + 1, -2e18);
for (int i = 0; i <= n + m; i++) {//枚举总的白点个数
for (int j = max(0, i - m); j <= min(i, n); j++) {//枚举u中的白点个数
//max(0,i-m)是由于i-j<=m,min(i,n)是有由于i<=n
int l = j, r = i - j;
int paircnt = l * (m - r) + (n - l) * r;
temp[i] = max(temp[i], dp[u][l] + dp[v][r] + 1ll * paircnt * w);
}
}
dp[u] = temp;
p[v] = u;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n;
cin >> n;
for (int i = 1 ; i <= n; i++) p[i] = i, dp[i].resize(2);
vector<int>a(n + 1), cost(n + 1);
for (int i = 1; i <= n; i++)cin >> a[i];
for (int i = 1; i <= n; i++)cin >> cost[i];
for (int i = 1; i <= n; i++)dp[i][a[i]] = -cost[i];
vector<tuple<int, int, int>>edge;
for (int i = 1; i <= n - 1; i++) {
int u, v, w;
cin >> u >> v >> w;
edge.emplace_back(w, u, v);
}
sort(edge.begin(), edge.end());
for (auto [w, u, v] : edge) merge(u, v, w);
int root = leader(1);
ll ans = 0;
for (auto x : dp[root])ans = max(ans, x);
cout << ans << endl;
}