题解 CF1276F 【Asterisk Substrings】

Kewth

2020-03-21 15:08:36

Solution

把要统计的串分为以下几类:`(empty)`, `s`, `s*t`, `s*`, `*s`, `*` 。 `s` 的数量就是本质不同子串数,做法很经典。 而 `s*` 就是去掉最后一个点后的本质不同子串数,同理 `*s` 就是去掉第一个点。 `(empty)`, `*` 的数量就是一。 那么这题的重点就是 统计 `s*t` 的数量。 --- 考虑枚举本质不同的 s ,统计对应的 t 的数量。 不难发现 **对于 endpos 集合相同的 s ,其对应的 t 的数量是相同的** 。那么可以在 SAM 上枚举 $O(n)$ 个节点来计算 t 的数量。 设 s 的 endpos 集合为 $X$ ,对应地设集合 $Y = \{y=x+2|x \in X\}$ 。那么把起点在集合 $Y$ 的所有后缀拿来建 Trie ,Trie 的节点数量就是合法 t 的数量。 众所周知求 Trie 的节点数量可以子集容斥,但是复杂度太大。另一个容斥方法是按字典序从小到大插入字符串,设上一个插入的串为 m ,新插入串 n 后, Trie 树的节点数量会增加 $|n| - |lcp(n, m)|$ 。 也就是说只要把起点在集合 $Y$ 的后缀按字典序排序并且求出相邻 lcp 的和就可以解决问题。 维护每个点的集合 $Y$ ,由于 SAM 上一个点的 endpos 集合是其在 parent 树上的儿子的 endpos 集合的并,因此可以通过启发式合并来维护每个点的 endpos 集合,也就能维护每个点的集合 $Y$ 。 但是仅记录位置无法维护答案,集合 $Y$ 中事实上要记录的是每个后缀的 rank ,然后维护 lcp 的和就是在启发式合并的时候查询 height 数组上的区间最小值。 时间复杂度 $O(nlog^2n)$ ,瓶颈在于 set 启发式合并。 --- 参考实现: ```cpp #include <cstdio> #include <algorithm> #include <cstring> #include <vector> #include <set> #define debug(...) fprintf(stderr, __VA_ARGS__) typedef long long ll; // SA {{{ const int maxn = 100005; char s[maxn]; int sa[maxn], rank[maxn], hi[maxn]; int tmp[maxn], sa2[maxn], t[maxn]; void build (int n, int max) { for (int i = 1; i <= n; i ++) rank[i] = s[i], sa2[i] = i; for (int i = 1; i <= n; i ++) ++ t[rank[i]]; for (int i = 1; i <= max; i ++) t[i] += t[i - 1]; for (int i = n; i; i --) sa[t[rank[sa2[i]]] --] = sa2[i]; std::fill(t + 1, t + max + 1, 0); int tot = rank[sa[1]] = 1; for (int i = 2; i <= n; i ++) rank[sa[i]] = s[sa[i]] == s[sa[i - 1]] ? tot : ++ tot; for (int m = 1; tot < n; m <<= 1) { int p = 0; for (int i = 1; i <= m; i ++) sa2[++ p] = n - m + i; for (int i = 1; i <= n; i ++) if (sa[i] > m) sa2[++ p] = sa[i] - m; for (int i = 1; i <= n; i ++) ++ t[rank[i]]; for (int i = 1; i <= tot; i ++) t[i] += t[i - 1]; for (int i = n; i; i --) sa[t[rank[sa2[i]]] --] = sa2[i]; std::fill(t + 1, t + tot + 1, 0); std::copy(rank, rank + n + 1, tmp); rank[sa[1]] = tot = 1; for (int i = 2; i <= n; i ++) rank[sa[i]] = (tmp[sa[i]] == tmp[sa[i - 1]] and tmp[sa[i] + m] == tmp[sa[i - 1] + m]) ? tot : ++ tot; } for (int i = 1; i <= n; i ++) { int &k = hi[rank[i]] = std::max(0, hi[rank[i - 1]] - 1); while (s[i + k] == s[sa[rank[i] - 1] + k]) ++ k; } } // }}} // SAM {{{ const int maxk = 26; int ch[maxn << 1][maxk], len[maxn << 1], fa[maxn << 1], cp = 1; std::vector<int> G[maxn << 1]; int insert (int pre, int x) { int now = ++ cp; len[now] = len[pre] + 1; while (pre and !ch[pre][x]) { ch[pre][x] = now; pre = fa[pre]; } if (pre) { int preto = ch[pre][x]; if (len[preto] == len[pre] + 1) fa[now] = preto; else { int sp = ++ cp; len[sp] = len[pre] + 1; fa[sp] = fa[preto]; for (int i = 0; i < maxk; i ++) ch[sp][i] = ch[preto][i]; while (pre and ch[pre][x] == preto) { ch[pre][x] = sp; pre = fa[pre]; } fa[now] = fa[preto] = sp; } } else fa[now] = 1; return now; } // }}} int st[maxn][20], hb[maxn]; void stinit (int n) { for (int i = n; i; i --) { st[i][0] = hi[i]; for (int k = 1; i + (1 << k) - 1 <= n; k ++) st[i][k] = std::min(st[i][k - 1], st[i + (1 << k >> 1)][k - 1]); } for (int i = 2; i <= n; i ++) hb[i] = hb[i >> 1] + 1; } int stquery (int l, int r) { int k = hb[r - l + 1]; return std::min(st[l][k], st[r - (1 << k) + 1][k]); } struct Set { std::set<int> set; ll val; void insert (int k, int n) { if (set.count(k)) return; int l = *(-- set.lower_bound(k)); int r = *set.upper_bound(k); val += n - sa[k] + 1; if (l >= 1) val -= stquery(l + 1, k); if (r <= n) val -= stquery(k + 1, r); if (l >= 1 and r <= n) val += stquery(l + 1, r); set.insert(k); } void init (int n) { set.insert(0), set.insert(n + 1); } } set[maxn << 1]; ll ans; void dfs (int u, int n) { for (int v : G[u]) { dfs(v, n); if (set[v].set.size() > set[u].set.size()) std::swap(set[u], set[v]); for (int k : set[v].set) set[u].insert(k, n); } ans += 1ll * (len[u] - len[fa[u]]) * set[u].val; } int main () { scanf("%s", s + 1); int n = int(strlen(s + 1)); int sam = 1; for (int i = 1; i <= n; i ++) sam = insert(sam, s[i] - 'a'); build(n, 256); for (int i = 1; i <= cp; i ++) set[i].init(n); stinit(n); sam = 1; set[sam].insert(rank[2], n); for (int i = 1; i <= n; i ++) { sam = ch[sam][s[i] - 'a']; if(i + 2 <= n) set[sam].insert(rank[i + 2], n); } // s for (int i = 1; i <= cp; i ++) ans += len[i] - len[fa[i]]; // s*s for (int i = 2; i <= cp; i ++) G[fa[i]].push_back(i); dfs(1, n); // s* for (int i = 1; i <= cp; i ++) if (i != sam) ans += len[i] - len[fa[i]]; ans += set[1].val; // *s ++ ans; // * ++ ans; // (empty) printf("%lld\n", ans); } ```