快速沃尔什变换(FWT)学习笔记

First Post:

Last Update:

Word Count:
1.2k

Read Time:
5 min

在 OI 中,FWT 是用于解决对下标进行位运算卷积问题的一种方法。

引入

给出序列 ,我们想要求出 ,满足

其中 代表按位与,按位或,按位异或中的一种。

直接求是 的,而 FWT 通过构造出一种可逆的变换 ,使得 ,从而快速解决下标位运算卷积。

FWT 的运算

以下设 为按位或, 为按位与, 为按位异或。

按位或

考虑构造 ,则

如果没看明白就把 展开推一遍,然后就能发现这是一样的。

那么我们现在只需要能做到快速求出 且能快速进行逆运算就做完了。

可以发现,这个 ,不就是求子集和吗?

对每个子集求子集和,这不是高维前缀和吗?

考虑逆运算,显然只需要做一遍高维差分即可。

复杂度 ,这里的 指的是维数。如果把 变成序列长度,则复杂度为

按位与

考虑构造 ,可以发现形式和按位或的情况一模一样。

证明也类似按位或,这里不再赘述。

可以发现,和按位或的唯一区别在于之前是求子集和,现在是求超集和。

依旧可以高维前缀和。

复杂度

按位异或

因为还没学明白,所以基本都是抄的 OI Wiki

那么可以得到 有分配律:

考虑构造 ,则

如何快速计算 ?考虑分治。

在当前位为 0 的子数列为 ,在当前位为 1 的子数列为 ,则

其中 表示拼接, 表示对应位置加/减。

逆变换为

复杂度

例题:【模板】快速莫比乌斯/沃尔什变换 (FMT/FWT)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include"bits/stdc++.h"
#define re register
#define int long long
using namespace std;
const int maxn=(1<<17)+10,mod=998244353,inv2=499122177;
int n,S;
int A[maxn],B[maxn];
int a[maxn],b[maxn],c[maxn];
inline void in(){for(re int i=0;i<=S;++i) a[i]=A[i],b[i]=B[i];}
inline void get(){for(re int i=0;i<=S;++i) c[i]=a[i]*b[i]%mod;}
inline void out(){for(re int i=0;i<=S;++i) cout<<c[i]<<" ";cout<<'\n';}
inline void OR(int f[],int op){
for(re int i=0;i<n;++i){
for(re int s=0;s<=S;++s){
if((s>>i)&1) f[s]=((f[s]+f[s^(1<<i)]*op)%mod+mod)%mod;
}
}
}
inline void AND(int f[],int op){
for(re int i=0;i<n;++i){
for(re int s=0;s<=S;++s){
if((s>>i)&1) f[s^(1<<i)]=((f[s^(1<<i)]+f[s]*op)%mod+mod)%mod;
}
}
}
inline void XOR(int f[],int op){
for(re int i=0,len=1;i<n;++i,len<<=1){
for(re int s=0;s<=S;s+=len*2){
for(re int j=0;j<len;++j){
f[s+j]=(f[s+j]+f[s+j+len])%mod;
f[s+j+len]=((f[s+j]-f[s+j+len]*2)%mod+mod)%mod;
f[s+j]=((f[s+j]*op)%mod+mod)%mod;
f[s+j+len]=((f[s+j+len]*op)%mod+mod)%mod;
}
}
}
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("1.in","r",stdin);
freopen("1.out","w",stdout);
#endif
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>n;S=(1<<n)-1;
for(re int i=0;i<=S;++i) cin>>A[i];
for(re int i=0;i<=S;++i) cin>>B[i];
in(),OR(a,1),OR(b,1),get(),OR(c,-1),out();
in(),AND(a,1),AND(b,1),get(),AND(c,-1),out();
in(),XOR(a,1),XOR(b,1),get(),XOR(c,inv2),out();
return 0;
}