本文移植自个人的原博客(原文

题目链接

Shuffle and Swap

题目描述

题目描述

题解

题目大意为将$A$与$B$中所有为一处的下标分别加入$a$,$b$集合,求有多少种$a$,$b$的排列满足在依次交换$A_{a_i}$,$A{b_i}$后满足$A=B$。

交换操作实际为将$A$中的$1$调整匹配到$B$中$1$的位置,所以当不存在$A_i=B_i=1$时,不论如何打乱$a$,$b$的排列,在操作后总有$A=B$。当存在$A_i=B_i=1$时,$Ai$的交换有两种情况:

$A_i$所对应的$B_i$为$1$。

$A_i$所对应的$B_i$为$0$。

在进行第一种交换后,如果与$0$进行交换后,需要在操作结束后重新交换回来。那么我们不妨从$A_i$到$B_i$连一条边,易知每一个点出度入度均不超过$1$,且根据以上结论,设$x$为$A_i=B_i=1$的个数,$y$为$A_i!=B_i \ and \ A_i=1$的个数,可知该生成图共有$y$条链,且$x$个点可以作为变换的中转节点加入$y$条链中,那么我们可设$f(i,j)$表示在前$i$条链加入$j$个节点点的方案数,转移方程如下:

f[0][0]=1;
for(int i=1;i<=y;++i)
for(int j=0;j<=x;++j)
for(int k=0;k<=j;++k)
(f[i][j]+=f[i-1][j-k]*rev[k+1])%=mod;(rev[k+1]为k+1阶乘的逆元)

该DP时间复杂度为$O(n^3)$,期望得分为$1200$。
考虑进行优化,观察可知$f[i−1][j−k]×rev[k+1]$为卷积形式,可以用$NTT$进行优化,代码如下。

f[0][0]=1;
for(int i=1;i<=y;++i)
{
for(int j=x+1;j<=x*3;++j)a[j]=b[j]=0;
for(int j=0;j<=x;++j)a[j]=f[(i-1)&1][j],b[j]=rev[j+1];
NTT(a,b,x);
for(int j=0;j<=x;++j)f[i&1][j]=a[j];
}

此处观察易知,该方程实质上就是将$f$数组乘了$y$遍$rev$数组,加上快速幂优化即可通过极限数据,期望得分$1700$。

a[0]=1;
for(int i=0;i<=x;++i)b[i]=rev[i+1];
for(int t=y;t;t>>=1)
{
for(int i=x+1;i<N;++i)a[i]=b[i]=c[i]=d[i]=0;
if(t&1)NTT(a,b,x,0);
NTT(b,b,x,1);
}

完整代码如下

#include<bits/stdc++.h>
#define inv(x) qpow(x,mod-2)
using namespace std;
typedef long long ll;

const int N = 5e4+1;
const ll mod = 998244353;
const ll G = 3;
const double pi = acos(-1);

char a1[N],b1[N];
int x,y;
ll rev[N],fac[N],r[N];
ll a[N],b[N],c[N],d[N];

int qpow(ll a,int b)
{
ll ret=1;
while(b)
{
if(b&1)ret=(ret*a)%mod;
b>>=1;
a=(a*a)%mod;
}
return ret;
}

void pre()
{
ll i;
for(i=1,fac[0]=1;i<N;++i)fac[i]=(fac[i-1]*i)%mod;
for(rev[N-1]=inv(fac[N-1]),i=N-2;i>=0;--i)rev[i]=(rev[i+1]*(i+1))%mod;
}

inline void ntt(ll *a,int f,ll n)
{
for(int i=0;i<n;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int i=1,t=1;i<n;i<<=1,++t)
{
ll wn=qpow(G,(mod-1)/(1<<t));
if(f==-1) wn=inv(wn);
for(int j=0;j<n;j+=(i<<1))
for(ll k=0,w=1;k<i;k++,w=w*wn%mod)
{
ll x=a[j+k],y=a[j+k+i]*w%mod;
a[j+k]=(x+y)%mod;a[j+k+i]=(x-y+mod)%mod;
}
}
}

void NTT(ll *a,ll *b,ll len,bool tag)
{
ll n=1,m=0;
while(n<=2*len)n<<=1,m++;
for(int i=0;i<n;++i)r[i]=r[i>>1]>>1|(1&i)<<(m-1);
copy(a,a+x+1,c);copy(b,b+x+1,d);
for(int i=x+1;i<N;++i)a[i]=b[i]=c[i]=d[i]=0;
ntt(c,1,n);ntt(d,1,n);
for(int i=0;i<n;++i)(c[i]*=d[i])%=mod;
ntt(c,-1,n);ll t=inv(n);
for(int i=0;i<=len;++i)a[i]=c[i]*t%mod;
return;
}

/*

1101011011110
0111101011101

*/

int main()
{
pre();
scanf("%s%s",a1+1,b1+1);
int len=strlen(a1+1);
for(int i=1;i<=len;++i)
if(b1[i]=='1'&&a1[i]=='1')x++;
else if(a1[i]=='1')y++;
a[0]=1;
for(int i=0;i<=x;++i)b[i]=rev[i+1];
for(int t=y;t;t>>=1)
{
for(int i=x+1;i<N;++i)a[i]=b[i]=c[i]=d[i]=0;
if(t&1)NTT(a,b,x,0);
NTT(b,b,x,1);
}
ll ans=0;
for(int i=0;i<=x;++i)(ans+=a[i]*fac[x]%mod*fac[y]%mod*fac[x+y])%=mod;
cout<<ans;
}