CF2075E E. XOR Matrix题解

E. XOR Matrix

题目翻译

​ 题目给我们n,m,A,B。我们需要确认有多少种不同的数组(a, b)对是好的,a,b均是数组,a中的元素在0到A之间,b中的元素在0到B之间。数组中a中的任意元素异或上b中的任意元素得到的结果最多只能有2种。输出多少种,在模998244353的意义下。

思路

​ 首先我们发现可以分开来考虑相同结果为1种和相同结果为2种的情况,对于结果只有一种的数,一定是数组a中为i,数组b中为j,共两种不同的数。对于相同的结果为2的情况一定是数组a为i,数组b为j,k,或者反过来。结果为2还有可能是(i,j)for a, (k, l) for b,此时还要满足i ^ j ^ k ^ l == 0,我们发现只有最后一种情况我们比较难以完成计数,其余的情况我们都可以通过组合数学的公式快速推到出来。

​ 考虑如何计算四个数字i,j,k,l满足异或和为0,且i != j 并且k != l,同时i,j属于[0, A] ,k,l属于[0, B]。我们发现我们可以拆成一位一位的看,每位中为1的位的个数一定要是偶数,这样的话我们可以使用类似数位dp的方式来完成计数,标记每个数字是否处于上限和i是否等于j,k是否等于l即可。注意我们dp出来的结果在后面考虑实际插入时会偏大,我们考虑插入的实际意义实际上是真实的4倍,需要除以4(例如(10,10)(01,10)(10,01)(01,01), 会被我们每种01,10考虑到)。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
constexpr ll mod = 998244353;

ll qmi(ll a, ll k, ll m){ //求a^k mod m
a %= m;
ll res = 1 % m;
while (k)
{
if (k&1) res = res * a % m;//指数k为1的位乘上a^(1<<x)mod m
a = a * a % m;//每一项是前一项的平方模m
k >>= 1;
}
return res;
}

void solve(){
ll n, m, A, B;
std::cin >> n >> m >> A >> B;
//std::cerr << n << space << m << space << A << space << B << endl;

ll dp[31][2][2][2][2][2][2];
memset(dp, -1, sizeof(dp));
std::function<ll(int, int, int, int, int, int, int)> dfs = [&](int pos, int limit1, int limit2, int same1, int limit3, int limit4, int same2) -> ll{
if(pos < 0){
return ((!same1) && (!same2));
}
//std::cerr << pos << space << limit1 << limit2 << limit3 << limit4
if(dp[pos][limit1][limit2][same1][limit3][limit4][same2] != -1)
return dp[pos][limit1][limit2][same1][limit3][limit4][same2];
ll ans = 0;
for(int i = 0; i <= 1; i++){
for(int j = 0; j <= 1; j++){
for(int k = 0; k <= 1; k++){
for(int l = 0; l <= 1; l++){
if((i + j + k + l) % 2 == 0){
int up1 = (A >> pos) & 1;
int up2 = (B >> pos) & 1;
if(i > up1 && limit1)
continue;
if(j > up1 && limit2)
continue;
if(k > up2 && limit3)
continue;
if(l > up2 && limit4)
continue;


ans = (ans + dfs(pos - 1, limit1 && (i == up1), limit2 && (j == up1), (i == j) & same1,
limit3 && k == up2, limit4 && l == up2, (k == l) & same2)) % mod;
//std::cerr << i << j << k << l << endl;
}
}
}
}
}

//std::cerr << pos << space << limit1 << limit2 << limit3 << limit4 << space << ans << endl;
return dp[pos][limit1][limit2][same1][limit3][limit4][same2] = ans;
};
ll res = dfs(30, true, true, true, true, true, true);
res = res * qmi(4, mod - 2, mod) % mod;

ll ans = 0;
ans = (res *(qmi(2, n, mod) - 2) % mod * (qmi(2, m, mod) - 2) % mod + ans) % mod;

ans = (ans + (A + 1) * (B + 1) % mod) % mod;

ans = (ans + (A + 1) * (B + 1) % mod * B % mod * (qmi(2, m, mod) - 2) % mod * qmi(2, mod - 2, mod) % mod) % mod;
ans = (ans + (B + 1) * (A + 1) % mod * A % mod * (qmi(2, n, mod) - 2) % mod * qmi(2, mod - 2, mod) % mod) % mod;

ans = (ans + mod) % mod;
std::cout << ans << endl;
}