本文移植自个人的原博客(原文

题目链接

Dating

题目描述

题目描述

题解

树上莫队裸题,题目大意为给$q$次询问,询问两点之间的链中权值相同的男生个数与女生个数的乘积,并对所有权值求和。

对于序列中的莫队算法来说,需要对于先对于左区间所在的块为第一关键字,右区间为第二关键字进行查询,树上莫队也不例外。

首先应该用DFS序表示该树。对于子树的查询,在DFS序中为连续的一段,其求解与序列中的莫队相同,而对于链的查询,其在序列中不为连续的一段,所以需要特殊的处理:

  • $LCA$为两端点中的任意一个,则应插入$[st[a],st[b]]$区间,容易发现除了链上的点被计算了一遍,其他点都被计算了$0$遍或$2$遍,可以将偶数次的插入视为删除即可。

  • $LCA$不为两节点中的任意一个,此处应插入$[ed[a],st[b]]$区间,计算方法与上文相同,但需要注意此时利用的是DFS序而不是欧拉序,$LCA$实际上没有参与运算,加上即可,单个不在区间中节点的插入对时间复杂度没有影响。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+1;
const int MXB =19;
int head[N],cnt,n,Q,id[N],f1[N],f2[N],d[N],f[N][MXB],st[N],ed[N],dfx[N*2],totw,block;
int tong[2][N];
ll ans[N];
bool inq[N];
//int th[N];
struct nd{int ne,to;}e[N*2];
struct questions{int id,m,l,r,b;}q[N*2];

bool cmp(questions a,questions b)
{
if(a.b!=b.b)return a.b<b.b;
return a.r>b.r;
}

void in(int x,int y){e[++cnt].to=y;e[cnt].ne=head[x];head[x]=cnt;}

void dfs(int x,int fa)
{
st[x]=++totw;dfx[totw]=x;
for(int i=head[x];i;i=e[i].ne)
if(e[i].to!=fa)
{
int y=e[i].to;
d[y]=d[x]+1;
f[y][0]=x;
dfs(y,x);
}
ed[x]=++totw;dfx[totw]=x;
}

int lca(int x,int y)
{
if(d[x]<d[y])swap(x,y);
for(int i=MXB-1;i>=0;--i)
if(d[f[x][i]]>=d[y])x=f[x][i];
if(x==y)return x;
for(int i=MXB-1;i>=0;--i)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}

void Pre()
{
d[1]=1;dfs(1,-1);
for(int i=1;i<MXB;++i)
for(int t=1;t<=n;++t)
f[t][i]=f[f[t][i-1]][i-1];
}
ll ret=0;
void insert(int x)
{
inq[x]^=1;
if(inq[x])
{
tong[id[x]][f1[x]]++;
ret+=tong[id[x]^1][f1[x]];
// th[x]++;
}
else
{
tong[id[x]][f1[x]]--;
ret-=tong[id[x]^1][f1[x]];
// th[x]--;
}
}


int main()
{
// freopen("1.txt","r",stdin);
scanf("%d",&n);
block=sqrt(n)+1;
for(int i=1;i<=n;++i)scanf("%d",id+i);
for(int i=1;i<=n;++i)scanf("%d",f1+i),f2[i]=f1[i];
sort(f2+1,f2+n+1);
int nn=unique(f2+1,f2+n+1)-f2-1;
for(int i=1;i<=n;++i)f1[i]=lower_bound(f2+1,f2+nn+1,f1[i])-f2;
// for(int i=1;i<=n;++i)cout<<id[i]<<" ";cout<<endl;
// for(int i=1;i<=n;++i)cout<<f1[i]<<" ";cout<<endl;
for(int i=1,x,y;i<n;++i)scanf("%d%d",&x,&y),in(x,y),in(y,x);
Pre();
scanf("%d",&Q);
for(int i=1,a,b;i<=Q;++i)
{
scanf("%d%d",&a,&b);
q[i].id=i;
int p=lca(a,b);
if(st[a]>st[b])swap(a,b);
if(p==a)q[i].l=st[a],q[i].r=st[b],q[i].m=0;
else q[i].l=ed[a],q[i].r=st[b],q[i].m=p;
q[i].b=q[i].l/block+1;
}
int L=1,R=0;
sort(q+1,q+Q+1,cmp);
for(int i=1;i<=Q;++i)
{
int l=q[i].l,r=q[i].r;
while(L<l)insert(dfx[L]),L++;
while(L>l)L--,insert(dfx[L]);
while(R>r)insert(dfx[R]),R--;
while(R<r)R++,insert(dfx[R]);
if(q[i].m)insert(q[i].m);
// for(int t=1;t<=n;++t)cout<<th[t]<<" ";cout<<endl;
ans[q[i].id]=ret;
if(q[i].m)insert(q[i].m);
// cout<<endl<<endl<<endl;
}
for(int i=1;i<=Q;++i)printf("%I64d\n",ans[i]);
}