仅用于(也专用于)写题解,其他的去 kewth.github.io 玩啊~~

题解 P4516 【[JSOI2018]潜入行动】

2019-12-30 21:25:50


大家都是四维 DP $f[u][k][0/1][0/1]$ ,这题难就难在上下界优化的复杂度分析和转移。 上下界优化的复杂度分析就不细讲了,关键点在于进行 $O(k^2)$ 的背包转移的次数的是 $O(\frac{n}{k})$ 的。

而转移我想说这么多 01 写起来不累嘛。。。

事实上对于这个 01 可以用变量表示,假设当前合并两个背包 $f[u][a][p1][q1]$ 和 $f[v][b][p2][q2]$ ,
其中 $v$ 是 $u$ 的儿子。

考虑合并后的 $f[u][a+b][p3][q3]$ ,$p3$ 和 $q3$ 分别会是什么,以及这个合并什么时候合法。

$p3$ 是合并后点 $u$ 是否安装监听器,这显然不受 $v$ 影响,也就是说 $p3=p1$ 。

$q3$ 是合并后点 $u$ 是否被监听,有两种情况:$u$ 之前已经被监听,$u$ 现在被 $v$ 监听。即:$q3=q1|p2$ 。

但是不是所有状态都能合并的,可以合并当且仅当 $v$ 被监听,
$v$ 被监听同样有两种情况:$v$ 之前被监听,$v$ 现在被 $u$ 监听。
也就是说当且仅当 $q2|p1=1$ 时这两个状态可以合并。

于是得到一个清爽的 4 个 for 转移:

for(int p1 = 0; p1 < 2; p1 ++)
    for(int q1 = 0; q1 < 2; q1 ++)
        for(int p2 = 0; p2 < 2; p2 ++)
            for(int q2 = 0; q2 < 2; q2 ++)
                if(q2 | p1)
                    (newf_u[a + b][p1][q1 | p2] +=
                            f[u][a][p1][q1] * f[v][b][p2][q2]) %= mod;

这样不容易写错,如果全写 01 手动枚举所有情况,很容易漏算或者写错一个 01 之类的,
而且很难调试。

完整实现:

#include <cstdio>
#include <vector>

typedef long long ll;

const int maxn = 100005, maxk = 104, mod = 1000000007;
int f[maxn][maxk][2][2];
int g[maxk][2][2];
int lim[maxn];
std::vector<int> G[maxn];

inline void __a(int &x) { if(x >= mod) x -= mod; }

int N, K;
void dp(int u, int fa) {
    f[u][0][0][0] = 1;
    f[u][1][1][0] = 1;
    lim[u] = 1;

    for(int v : G[u])
        if(v != fa) {
            dp(v, u);

            int nlim = std::min(lim[u] + lim[v], K);
            for(int a = 0; a <= lim[u]; a ++)
                for(int b = 0; b <= lim[v] and a + b <= nlim; b ++) {
                    for(int p1 = 0; p1 < 2; p1 ++)
                        for(int q1 = 0; q1 < 2; q1 ++)
                            for(int p2 = 0; p2 < 2; p2 ++)
                                for(int q2 = 0; q2 < 2; q2 ++)
                                    if(q2 | p1)
                                        __a(g[a + b][p1][q1 | p2] +=
                                         1ll * f[u][a][p1][q1] * f[v][b][p2][q2] % mod);
                }

            for(int k = 0; k <= nlim; k ++)
                for(int p = 0; p < 2; p ++)
                    for(int q = 0; q < 2; q ++)
                        f[u][k][p][q] = g[k][p][q];
            for(int k = 0; k <= nlim; k ++)
                g[k][0][0] = g[k][0][1] = g[k][1][0] = g[k][1][1] = 0;
            lim[u] = nlim;
        }
}

int main() {
    scanf("%d %d", &N, &K);
    for(int i = 1; i < N; i ++) {
        int u, v;
        scanf("%d %d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dp(1, 0);
    printf("%d\n", (f[1][K][0][1] + f[1][K][1][1]) % mod);
}