Problem Description
给定一棵 nnn 个节点的树,树的每条边权为 wiw_iwi 。
对一个图有以下 444 种限制:
- 图没有自环和重边。
- 图中边权为整数,且不超过 SSS 。
- 该图只有一棵最小生成树。
- 该图的最小生成树为给定的树。
求可行图的个数。
Input
第一行输入一个整数 t(1≤t≤104)t(1 \le t \le 10^4)t(1≤t≤104) ,表示测试组数。
对于每组测试,
第一行输入两个整数 n,S(2≤n≤2×105,1≤S≤109)n,S(2 \le n \le 2 \times 10^5,1 \le S \le 10^9)n,S(2≤n≤2×105,1≤S≤109) ,表示节点个数和边权上界。
接下来 n−1n-1n−1 ,每行输入三个整数 ui,vi,wi(1≤ui,vi≤n,ui≠vi,1≤wi≤S)u_i,v_i,w_i(1 \le u_i,v_i \le n,u_i \not= v_i,1 \le w_i \le S)ui,vi,wi(1≤ui,vi≤n,ui=vi,1≤wi≤S) ,表示 uiu_iui 和 viv_ivi 之间有一条边权为 wiw_iwi 的边。
题目保证 ∑n≤2×105\sum n \le 2 \times 10^5∑n≤2×105 。
Output
输出模 998244353998244353998244353 后的可行图的个数。
Solution
我们考虑 KruskalKruskalKruskal 生成的过程中,会发现当添加了一条权值为 wiw_iwi 的边后,会合并两个连通块得到一个新的连通块 bbb ,而在 bbb 中添加任意边权 wj∈(wi,S]w_j \in (w_i,S]wj∈(wi,S] 后,并不会妨碍最小生成树的唯一性。
那么这题的解法便变成了在做 KruskalKruskalKruskal 每次合并两个连通块 x,yx,yx,y 时,计算能够加边对于 ansansans 的贡献,贡献即为加边方案数 cntcntcnt ,然后令 ans=ans×cntans=ans \times cntans=ans×cnt 即可。但是我们会发现由于之前生成 x,yx,yx,y 连通块时已经计算了他们各自能够提供的方案数贡献,所以我们在每次计算方案数时将能加的边不考虑 x,yx,yx,y 内部生成的,那么新生成的边数就变成了 sizx×sizy−1siz_x \times siz_y - 1sizx×sizy−1,cnt=(S−wi+1)sizx×sizy−1cnt=(S-w_i+1)^{siz_x \times siz_y -1}cnt=(S−wi+1)sizx×sizy−1 ,这样我们就能不重不漏地计算出答案。
最终时间复杂度为 O(nlogn)O(nlogn)O(nlogn) 。
Code
#include <bits/stdc++.h>
#define endl '\n'
using namespace std;
typedef long long ll;
constexpr int mod = 998244353;
ll qmi(ll a, ll k) {
ll res = 1 % mod, t = a % mod;
while (k) {
if (k & 1)res = res * t % mod;
t = t * t % mod;
k >>= 1;
}
return res;
}
void solve() {
int n, S;
cin >> n >> S;
vector<ll>p(n + 1), siz(n + 1, 1); iota(p.begin(), p.end(), 0);
vector<tuple<int, int, int>>edge;
auto leader = [&](int x) {
while (x != p[x])x = p[x] = p[p[x]];
return x;
};
for (int i = 1; i <= n - 1; i++) {
int u, v, w;
cin >> u >> v >> w;
edge.emplace_back(w, u, v);
}
sort(edge.begin(), edge.end());
ll ans = 1;
for (auto [w, u, v] : edge) {
u = leader(u), v = leader(v);
ans = ans * qmi(S - w + 1, siz[u] * siz[v] - 1) % mod;
siz[u] += siz[v];
p[v] = u;
}
cout << ans << endl;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int t;
cin >> t;
while (t--)solve();
}