仅用于(也专用于)写题解,其他的去 kewth.github.io 玩啊~~

题解 P6730 【[WC2020] 猜数游戏】

2020-08-07 19:11:55


update: 加了一点订正。

感谢这题送我金牌

讲评的时候出题人提到这题的数据范围可以更大,感觉很有趣,就去想了想,于是搞了个看上去可行的做法,不过由于没有地方测试,我并不能完全保证这个做法正确,写这篇题解只是为了提供一个思路,如果我有什么弄错的地方还请各位指正。

不过光是口胡肯定也不行,所以我也写了代码,能通过这题(并且很快,交的时候是洛谷第二,loj 第一),如果有人也写了加强版的话可不可以把程序给我对拍一下啊 /kel 。

可能不会讲的很清楚,如果只是为了弄懂原题的话本题解可能不会带来很大帮助。

注意,以下默认 $n\le10^5, p\le10^{14}$ ,我的做法时间复杂度是 $O(nk\log p + d(\varphi(p))k)$ 的,其中 $d$ 表示约数个数函数,$k$ 表示 $p$ 的质因子个数。

题意转换

给定模数 $p$ ,定义一个数 $x$ 的导出集合为 $F(x) = \{y|x^k \equiv y \pmod{p}\}$ 。对于一个非空集合 $S$ ,称它的非空子集 $T$ 是优秀的,当且仅当 $T$ 中元素的导出集合的并集是 $S$ 的父集。$S$ 的权值定义为其最小的优秀子集的大小。

给定大小为 $n$ 的一个数集 $S$ ,求所有子集的权值和。另外,保证 $p=q^k$ ,其中 $q$ 是奇质数。

集合关系

先考虑一个关键的问题,如何判断 $y \in F(x)$ 是否成立。一个朴素的做法是 BSGS ,单次复杂度 $O(\sqrt{p})$ ,显然代价太大不能接受。

但是注意到 $p$ 是一定有原根 $g$ 的,那么如果 $q|x \lor q|y$ 可以直接枚举 $k$ 暴力判断,否则一定可以表示为 $x = g^a, y = g^b$ 。于是 $y \in F(x)$ 等价于关于 $k$ 的方程 $g^{ak} \equiv g^b$ 有解。该方程等价于 $ak \equiv b \pmod{\varphi(p)}$ ,显然有解的充要条件是

$$\gcd(a, \varphi(p)) | \gcd(b, \varphi(p))$$

注意到事实上 $\mathrm{ord}_p x = \frac{\mathrm{lcm}(\varphi(p), a)}{a} = \frac{\varphi(p)}{\gcd(a, \varphi(p))}$ ,因此上式还等价于

$$\mathrm{ord}_p y | \mathrm{ord}_p x$$

也就是说只需要求出 $S$ 每个数模 $p$ 意义下的阶,就可以简单判断这个集合关系了。求阶可以用试除法 + 快速幂,试除只需要枚举 $\varphi(p)$ 的每个质因子,复杂度 $O(k\log p)$ 。

统计答案

如果将 $y \in F(x)$ 这个关系连边 $x \rightarrow y$ ,那么一个点集的权值就是最小点覆盖。观察这个图,由于连边的唯一判断依据是阶,可以发现阶相同的数连成了一个团,把这个团看做一个整体的点的话,整张图就是一张 DAG 。那么算一个给定图的权值是容易的,就是没有入度的团的数量(注意这张图的连边有传递性,也就是分层后层数不超过 2)。

现在要统计所有子集的权值和,对于每个团单独算贡献,那么一个子集 $T$ 中该团产生贡献(作为没有入度的团)当且仅当

  1. 该团中至少一个点在 $T$ 中
  2. 连向该团的其他团的任何点都不在 $T$ 中。

设一个团大小为 $x$ ,有 $y$ 个不在该团的点连向该团,那么综上一个团的贡献就是 $(2^x-1)2^{n-y-x}$ 。

建图

可以发现尽管我们知道了连边方式,但是 $n$ 很大的时候显式地连边显然不行,不过我们并不关注具体连边,只需要关注每个点的入度。

统计一个阶为 $x$ 的团的入度的数量也就是统计有多少个数的阶是 $x$ 的倍数。目前我没有想到很好的解决方案,不过可以用高维前缀和做这个经典问题,复杂度 $O(d(\varphi(p))k)$ 。

细节

$p \le 10^{14}$ 的时候算快速幂可能要用到快速乘,我偷懒用了 __int128

高位前缀和的部分存储用 std::map 会带来额外的时间开销,可以用类似杜教筛的 trick 把下标分类讨论,空间复杂度 $O(\sqrt{p})$ 。

参考实现

再次提醒,因为我也没有测过,不保证在更大的数据范围不出错。

#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector>
#include <map>
#define debug(...) fprintf(stderr, __VA_ARGS__)

typedef long long ll;

static struct {
    inline operator int () { int x; return scanf("%d", &x), x; }
    inline operator ll () { ll x; return scanf("%lld", &x), x; }
} read;

ll mul (ll x, ll y, ll mod) {
    return __int128(x) * y % mod;
}

inline ll power (ll x, ll k, ll mod) {
    if (k < 0) k += mod - 1;
    ll res = 1;
    while (k) {
        if (k & 1) res = mul(res, x, mod);
        x = mul(x, x, mod);
        k >>= 1;
    }
    return res;
}

ll minp (ll n) {
    for (ll d = 2; d * d <= n; d ++)
        if (n % d == 0)
            return d;
    return n;
}

ll phi;
int sq;
std::vector<ll> v, V;
std::map<ll, int> S;
const int maxs = 20000005;
int G[maxs], GS[maxs];

void getd (ll n) {
    sq = int(std::sqrt(n));
    for (int d = 1; d <= sq; d ++)
        if (n % d == 0)
            V.push_back(n / d);
    for (int d = 1ll * sq * sq == n ? sq - 1 : sq; d; d --)
        if (n % d == 0)
            V.push_back(d);
    for (int d = 2; d <= n / d; d ++)
        if (n % d == 0) {
            v.push_back(d);
            while (n % d == 0) n /= d;
        }
    if (n > 1)
        v.push_back(n);
}

inline int id (ll x) {
    return int(x <= sq ? x : sq + phi / x);
}

const int maxn = 100005;
ll po[maxn];

int main () {
    int n = read;
    ll p = read, q = minp(p);
    phi = p / q * (q - 1);
    getd(phi);

    for (int i = 1; i <= n; i ++) {
        ll x = read, y = phi;
        if (x % q == 0) S[x] = 0;
        else {
            for (ll d : v)
            while (y % d == 0 and power(x, y / d, p) == 1)
                y /= d;
            ++ G[id(y)];
            ++ GS[id(y)];
        }
    }

    for (auto P : S) {
        ll x = P.first;
        while (x) {
            if (S.count(x)) ++ S[x];
            x = mul(x, P.first, p);
        }
    }

    for (ll d : v)
        for (ll x : V)
            if (phi / x % d == 0)
                GS[id(x)] += GS[id(x * d)];

    const int mod = 998244353;
    po[0] = 1;
    for (int i = 1; i <= n; i ++) po[i] = (po[i - 1] << 1) % mod;
    ll ans = 0;
    for (auto P : S)
        ans += po[n - P.second];
    ans %= mod;
    for (ll x : V)
        (ans += (po[G[id(x)]] - 1) * po[n - GS[id(x)]]) %= mod;
    printf("%lld\n", ans);
}