仅用于(也专用于)写题解,其他的去 kewth.github.io 玩啊~~

题解 CF840C 【On the Bench】

2020-03-16 11:38:20


复杂度更优秀的做法。

把一个数的平方因子都去掉,那么两个数的乘积为完全平方数当且仅当两个数相等,问题转换为给定一个序列,求这个序列没有相邻元素相同的排列数。

考虑容斥,设 $f_i$ 表示有至少 $i$ 对相等的数相邻的方案数, 那么答案就是 $\sum_{i=0}^n f_i (-1)^{n - i}$ 。

求 $f_i$ 就是钦定 $i$ 对数,把它们绑在一起看作整体求绑定的方案数再乘个绑定后的排列数 $(n - i)!$ 。

而值不同数之间是不会绑定的,也就是说不同的值之间互不影响,绑定方案数可以乘法原理。

那么对于值为 $x$ 的数求出 $g_{x,i}$ 表示在这些数中绑定 $i$ 对数的方案数,总的绑定方案数就是就是所有 $g_x$ 的卷积合并。

可惜模数 $10^9+7$ ,暴力卷积即可,复杂度 $O(n^2)$ ,理论上可以分治优化卷积做到 $O(nlog^2n)$ 。

参考实现:

(我的实现中 $O(n\sqrt{V})$ 暴力找的完全平方数)

#include <cstdio>
#include <cmath>
#include <algorithm>
#define debug(...) fprintf(stderr, __VA_ARGS__)

typedef long long ll;
struct {
    inline operator int () { int x; return scanf("%d", &x), x; }
    inline operator ll () { ll x; return scanf("%lld", &x), x; }
} read;

const int maxn = 5050, mod = 1000000007;
int a[maxn];
int b[maxn], bp;
ll f[maxn], g[maxn], h[maxn];
ll fac[maxn], ifac[maxn];

ll power(ll x, int k) {
    if(k < 0) k += mod - 1;
    ll res = 1;
    while(k) {
        if(k & 1) (res *= x) %= mod;
        (x *= x) %= mod;
        k >>= 1;
    }
    return res;
}

int main() {
    int n = read;
    int max = int(sqrt(1000000000));
    for(int i = 1; i <= n; i ++) {
        a[i] = read;
        for(int x = max; x; x --)
            if(!(a[i] % (x * x))) {
                a[i] /= x * x;
                break;
            }
    }

    std::sort(a + 1, a + n + 1);
    for(int i = 1; i <= n; i ++)
        if(a[i] != a[i - 1])
            b[++ bp] = 1;
        else
            ++ b[bp];

    fac[0] = 1;
    for(int i = 1; i <= n; i ++)
        fac[i] = fac[i - 1] * i % mod;
    ifac[n] = power(fac[n], -1);
    for(int i = n; i; i --)
        ifac[i - 1] = ifac[i] * i % mod;

    int lim = 0;
    f[0] = 1;

    for(int i = 1; i <= bp; i ++) {
        ll coe = 1;
        for(int j = b[i]; j; j --) {
            g[b[i] - j] = coe * ifac[b[i] - j] % mod;
            (coe *= j * (j - 1)) %= mod;
        }
        for(int j = 0; j <= lim + b[i]; j ++)
            h[j] = 0;
        for(int j = 0; j <= lim; j ++)
            for(int k = 0; k < b[i]; k ++)
                (h[j + k] += f[j] * g[k]) %= mod;
        lim += b[i] - 1;
        for(int j = 0; j <= lim; j ++)
            f[j] = h[j];
    }

    for(int j = 0; j <= lim; j ++)
        (f[j] *= fac[n - j]) %= mod;

    ll ans = 0;
    for(int j = 0; j <= lim; j ++)
        ans += j & 1 ? - f[j] : f[j];
    ans %= mod;
    if(ans < 0) ans += mod;

    printf("%lld\n", ans);
}