Weak Pair HDU - 5877
You are given a rooted tree of N nodes, labeled from 1 to N. To the ith node a non-negative value ai is assigned.An ordered pair of nodes (u,v) is said to be weak if
(1) u is an ancestor of v (Note: In this problem a node u is not considered an ancestor of itself);
(2) au×av≤k
.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, N and k, respectively.
The second line contains N space-separated integers, denoting a1 to aN.
Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes u and v , where node u is the parent of node v.
Constrains:
1≤N≤105
0≤ai≤109
0≤k≤1018
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
题意:
一颗树上,有n个节点,给出每个节点的权值。另外给出一个值k,问有多少对节点满足:
a[u]×a[v]≤ka[u]×a[v]≤k
u是v节点的祖先(u≠v)u是v节点的祖先(u≠v)
分析:
我们先考虑,红色节点为V,那么U为V的所有祖先节点中权值小于ka[v]ka[v]都满足。
所以我们思路不是固定一个祖先去往下找他的孩子是否满足条件,而是走到一个节点,去看他的祖先是否满足条件,这样因为dfs的性质,你走到一个节点,它的祖先必然已经走过了(否则到不了这个点),因此我们可以用数组记录祖先的情况
因为每次遍历所有祖先寻找小于等于ka[v]ka[v]的 必然超时,所以想到利用树状数组,计算值ka[v]ka[v]之前的和即可。这里需要进行离散化处理。
我们先把每个节点的权值a[i]存下来排个序,这样每次到一个节点v的时候二分查找第一个大于ka[v]ka[v]的位置,这样之前的部分就都是小于等于的了使用upper_bound即可,为什么不用lower_bound,因为lower_bound查找的是第一个大于等于的位置,所以如果找到的位置是等于的,后面还有可能有等于的,这样就不好确定了,很麻烦
那用upper_bound的时候找到的位置是第一个大于的,那么求和的时候不就多求了一个大于的不满足条件的点吗,解决这个问题是容易的,我们只需要更新的时候并不是更新原来位置,而是更新原来位置+1即可,这样用upper_bound找到的位置就是恰好所有满足条件的了,然后求和就可以了
也就是用c数组记录祖先情况,c数组初始化为0,说明还没有走任何点,每走到一个点,就更新加入进去,更新的位置为原来节点权值数组从小到大排好序的位置+1,注意退出一个节点的时候,这个点的影响就没有了,要更新删除这个点的影响
所以每次需要寻找两个下标
1 pos : 大于ka[v]ka[v]对应的第一个下标
2.posthis : 第一个大于a[v]a[v]的下标
ans += getsum(pos); // 即找到了当前节点为a[v],前面祖先节点有几个小于等于ka[v]ka[v]的值
add(posthis,1); //当前节点是下面节点的父节点,为了防止上面求和会把大于的也算上,这样我把更新的区间位置整体后移一个就可以啦
add(posthis,-1); //退出同理
code:
#include <bits/stdc++.h>
using namespace std;
#define N 200005
typedef long long ll;
int c[N];
int lowbit(int x){
return x & (-x);
}
void update(int i,int v){
for(; i < N; c[i] += v, i += lowbit(i));
}
int getsum(int i){
int ans = 0;
for(; i > 0; ans += c[i], i -= lowbit(i));
return ans;
}
vector<int> tree[N];
int cnt;
ll ans = 0;
ll n,k;
int f[N];
int a[N];
map<int,int> vis;
vector<int> vec;
void dfs(int rt){
int len = tree[rt].size();
int pos = upper_bound(vec.begin(),vec.end(),k / a[rt]) - vec.begin();
int posthis = upper_bound(vec.begin(),vec.end(),a[rt]) - vec.begin();
ans += getsum(pos);
update(posthis,1);
for(int i = 0; i < len; i++){
dfs(tree[rt][i]);
}
update(posthis,-1);
}
void init(){
cnt = 1;
ans = 0;
vec.clear();
vis.clear();
for(int i = 0; i < 200003; i++){
tree[i].clear();
f[i] = 0;
c[i] = 0;
}
}
int main(){
int t;
scanf("%d",&t);
while(t--){
init();
scanf("%lld%lld",&n,&k);
for(int i = 1; i <= n; i++){
scanf("%d",&a[i]);
if(!vis[a[i]]){
vec.push_back(a[i]);
vis[a[i]] = 1;
}
}
sort(vec.begin(),vec.end());
for(int i = 1; i < n; i++){
int u,v;
scanf("%d%d",&u,&v);
tree[u].push_back(v);
f[v] = 1;
}
for(int i = 1; i <= n; i++){
if(!f[i]){
dfs(i);
break;
}
}
printf("%lld\n",ans);
}
return 0;
}