Weak Pair

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 262144/262144 K (Java/Others)
Problem Description
You are given a rooted tree of $N$ nodes, labeled from 1 to $N$. A non-negative value $a_i$ is assigned to the $i$-th node. An ordered pair of nodes $(u, v)$ is said to be weak if
  (1) $u$ is an ancestor of $v$ (Note: in this problem a node $u$ is not considered an ancestor of itself);
  (2) $a_u \times a_v \le k$.

Can you find the number of weak pairs in the tree?


Input
There are multiple test cases in the data set.
  The first line of input contains an integer $T$ denoting the number of test cases.
  For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
  The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
  Each of the subsequent $N-1$ lines contains two space-separated integers $u$ and $v$ defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.



Output
For each test case, print a single integer on its own line: the number of weak pairs in the tree.


Sample Input
1
2 3
1 2
1 2


Sample Output
1
// Weak Pair: count ordered pairs (u, v) such that u is a proper ancestor of v
// and a[u] * a[v] <= k.
//
// Approach: walk each tree of the input forest depth-first while a Fenwick
// tree holds limit[u] for every node u on the current root-to-node path, where
//     limit[u] = k / a[u]            when a[u] > 0,
//     limit[u] = LLONG_MAX           when a[u] == 0.
// Assuming k >= 0 (TODO confirm against the judge's constraints), integer
// division is floor division here, so for a[u] > 0
//     a[u] * a[v] <= k  <=>  a[v] <= k / a[u].
// When a[u] == 0 the product is 0 <= k, so u pairs with every descendant.
// LLONG_MAX is used for that case because a larger literal such as 1e19 does
// not fit in a long long (the conversion would be undefined behaviour).
// All a[i] and limit[i] values are coordinate-compressed together so they can
// index the Fenwick tree. The traversal is iterative so that a chain-shaped
// tree cannot overflow the call stack.
#include <algorithm>
#include <climits>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

// Fenwick (binary indexed) tree of counts over 1-based positions [1, size].
class Fenwick {
public:
    explicit Fenwick(int size) : tree_(static_cast<std::size_t>(size) + 1, 0) {}

    // Adds `delta` to the count stored at `pos` (1 <= pos <= size).
    void add(int pos, int delta) {
        const int size = static_cast<int>(tree_.size()) - 1;
        for (int i = pos; i <= size; i += i & -i) {
            tree_[i] += delta;
        }
    }

    // Returns the sum of counts at positions [1, pos]; pos == 0 yields 0.
    int prefix(int pos) const {
        int sum = 0;
        for (int i = pos; i > 0; i -= i & -i) {
            sum += tree_[i];
        }
        return sum;
    }

    // Returns the sum of every stored count.
    int total() const { return prefix(static_cast<int>(tree_.size()) - 1); }

private:
    std::vector<int> tree_;
};

// Returns the 1-based rank of `x` within the sorted, de-duplicated `sorted`.
// `x` must be an element of `sorted`.
int rank_of(const std::vector<long long>& sorted, long long x) {
    return static_cast<int>(std::lower_bound(sorted.begin(), sorted.end(), x) -
                            sorted.begin()) + 1;
}

// Returns the number of weak pairs in the forest described by `parent`.
//   n      - number of nodes, labelled 1..n.
//   k      - product threshold (assumed k >= 0, see file header).
//   value  - value[i] is a_i for i in 1..n; value[0] is unused.
//   parent - parent[i] is the parent of node i, or 0 when i is a root.
long long count_weak_pairs(int n, long long k,
                           const std::vector<long long>& value,
                           const std::vector<int>& parent) {
    const std::size_t slots = static_cast<std::size_t>(n) + 1;

    std::vector<std::vector<int> > children(slots);
    for (int v = 1; v <= n; ++v) {
        if (parent[v] != 0) {
            children[parent[v]].push_back(v);
        }
    }

    // limit[u] is the largest a[v] for which a[u] * a[v] <= k still holds.
    std::vector<long long> limit(slots, 0);
    std::vector<long long> coords;
    coords.reserve(2 * static_cast<std::size_t>(n));
    for (int i = 1; i <= n; ++i) {
        limit[i] = (value[i] == 0) ? LLONG_MAX : k / value[i];
        coords.push_back(value[i]);
        coords.push_back(limit[i]);
    }
    std::sort(coords.begin(), coords.end());
    coords.erase(std::unique(coords.begin(), coords.end()), coords.end());

    std::vector<int> value_rank(slots, 0);
    std::vector<int> limit_rank(slots, 0);
    for (int i = 1; i <= n; ++i) {
        value_rank[i] = rank_of(coords, value[i]);
        limit_rank[i] = rank_of(coords, limit[i]);
    }

    Fenwick on_path(static_cast<int>(coords.size()));
    long long pairs = 0;
    // Each frame is (node, index of the next child of node to visit).
    std::vector<std::pair<int, std::size_t> > stack;
    for (int root = 1; root <= n; ++root) {
        if (parent[root] != 0) {
            continue;
        }
        // Every entry added during the previous tree has been removed, so the
        // path is empty here and the root itself contributes no pairs.
        on_path.add(limit_rank[root], 1);
        stack.push_back(std::make_pair(root, static_cast<std::size_t>(0)));
        while (!stack.empty()) {
            const int node = stack.back().first;
            const std::size_t next = stack.back().second;
            if (next < children[node].size()) {
                ++stack.back().second;
                const int child = children[node][next];
                // Path entries with limit >= a[child] are exactly the
                // ancestors that form a weak pair with `child`.
                pairs += on_path.total() - on_path.prefix(value_rank[child] - 1);
                on_path.add(limit_rank[child], 1);
                stack.push_back(std::make_pair(child, static_cast<std::size_t>(0)));
            } else {
                // Leaving `node`: it is no longer an ancestor of later nodes.
                on_path.add(limit_rank[node], -1);
                stack.pop_back();
            }
        }
    }
    return pairs;
}

}  // namespace

// Reads T test cases from stdin and prints one weak-pair count per line.
int main() {
    int tests = 0;
    if (std::scanf("%d", &tests) != 1) {
        return 0;
    }
    while (tests-- > 0) {
        int n = 0;
        long long k = 0;
        if (std::scanf("%d%lld", &n, &k) != 2) {
            return 0;
        }
        std::vector<long long> value(static_cast<std::size_t>(n) + 1, 0);
        for (int i = 1; i <= n; ++i) {
            if (std::scanf("%lld", &value[i]) != 1) {
                return 0;
            }
        }
        // Each of the n - 1 edge lines is "u v" with u the parent of v.
        std::vector<int> parent(static_cast<std::size_t>(n) + 1, 0);
        for (int edge = 1; edge < n; ++edge) {
            int u = 0;
            int v = 0;
            if (std::scanf("%d%d", &u, &v) != 2) {
                return 0;
            }
            parent[v] = u;
        }
        std::printf("%lld\n", count_weak_pairs(n, k, value, parent));
    }
    return 0;
}

