虚树学习笔记

这玩意怎么变成 8 级了/ll

SDOI2020

problem

给定一颗 nn 个节点的树,每个边删掉后都有 aia_i 的代价。一共 mm 次询问,每次给定 kk 个关键点,求使得 11 号节点和所有关键点都不连通的最小代价。n2.5×105,m5×105,k5×105n \le 2.5 \times 10^5 ,m \le 5 \times 10^5, \sum k \le 5\times 10^5

暴力 DP

首先看到这个问题第一眼,我们都会一个 DP 。每个询问单独考虑,记 dpidp_iii 不与 ii 子树内的任意关键点联通的最小代价,DP 的过程也很简单,枚举 ii 的每个儿子 jj ,记 i,ji,j 之间的边权为 wi,jw_{i,j} ,那么有:

  • jj 是关键点,那么 dpi=dpi+wi,jdp_i = dp_i + w _ {i,j}
  • 否则,dpi=dpi+min(wi,j,dpj)dp_i = dp_i + \min (w_{i,j},dp _j)

最后 dp1dp_1 就是答案。这样的 DP 时间复杂度是 O(nm)O(nm) 的,考虑优化。

建树

首先,我们猜测只有这三类点是有用的:

  1. 关键点
  2. 任意两个关键点的 LCA

前两者很好理解,这里证明一下第三者。

首先,一个关键点会将信息传递给自己到根的路径上的所有点。那么一点会与有 xx 条路径经过的儿子的 dpdp 值不同,当且仅当其是某大于等于 x+1x+1 条路径的交点。那么除了根和关键点以外只有两个关键点的 LCA 有贡献。

于是,我们就可以把原来树上的无用点的信息都压到有用点上的,这样缩成的一颗节点较少的树就叫做 虚树 ,在虚树上进行操作。

对于这道题,我们很容易想到可以枚举每一对关键点,然后求出其 LCA 并建树,但是这样的复杂度是 O(k2logn)O(k^2 \log n) ,极端情况下甚至比朴素 DP 还要糟糕,那么我们就需要更快的建树方式。

之所以枚举的复杂度为 O(k2logn)O(k^2\log n) ,是因为在枚举 {a1,a2}\{a_1,a_2\} 时,会出现大量重复的 lca(a1,a2)\text{lca}(a_1,a_2)

我们建立点集 AA ,并把所有键点和 11 加入 AA 。接着将关键点按 DFS 序排序,再遍历关键点序列,把相邻两点 xxyy 的 LCA 也加入序列 AA

由于 DFS 序的特性,此时 AA 已包含了所有 虚树上的点 ,接下来只要考虑如何用 AA 建树即可。

AA 再次排序并去重,再遍历一遍 AA ,对于相邻点 xxyy ,连接 lca(x,y)\text{lca}(x,y)yy ,最终得到的就是虚树。

考虑为什么可以直接连 lca(x,y)\text{lca}(x,y)

这是因为,对于点 xxyy

  • xxyy 的祖先节点,则 lca(x,y)\text{lca}(x,y) 即为 xx 。由于 不存在zA{x,y}z \in \complement_A{\{x,y\}} 的 DFS 序在 xxyy 之间,故 xxyy 的路径上 没有 其它 AA 中的点,可以直接连。
  • xx 不是 yy 的祖先节点,对于 lca(x,y)\text{lca}(x,y)yy ,同样 不存在 zA{x,y}z \in \complement_A{\{x,y\}}lca(x,y)\text{lca}(x,y)yy 的路径之间,可以直接连。
    这里我们可以用反证法:若存在 zA{x,y}z \in \complement_A{\{x,y\}}lca(x,y)\text{lca}(x,y)yy 的路径之间,则 dfn(x)<dfn(z)<dfn(y)dfn(x)<dfn(z)<dfn(y),不满足 AA 按 DFS 排序后的性质,故命题得证。

这样,我们就在 O(klogn)O(k \log n) 的时间内建出了一颗虚树,且点的数量级是 O(k)O(k) 的。

DP

考虑把原来的那个 DP 丢到虚树上来做。

wiw_i 为在原来的树上, ii11 的路径权值最小值, dpidp_i 为使 ii 不与其子树内的点联通的最小代价。

对于每个 ii ,枚举其儿子 jj ,则有:

  • jj 是关键点,那么 dpi=dpi+wjdp_i=dp_i + w_j
  • jj 不是关键点,那么 dpi=dpi+min(dpj,wj)dp_i=dp_i + \min (dp_j,w_j)

这个也很好理解,因为只需要断开原来树上的边权最小的边就可以了。

dp1dp_1 就是答案。

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1e6+9;
const int INF=1e18;
const int MOD=998244353;
int n;
vector<int> G[maxn];
int m;
unordered_map<int,int> a[maxn];
int w[maxn];
int fa[maxn][30];
int dep[maxn];
int dfn[maxn];
int tot;
void dfs(int u,int f)
{
dfn[u]=++tot;
fa[u][0]=f;
dep[u]=dep[f]+1;
w[u]=min(a[f][u],w[f]);
for(int i=1;i<30;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
for(auto v:G[u]) if(v!=f) dfs(v,u);
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int j=29;j>=0;j--) if(dep[fa[x][j]]>=dep[y]) x=fa[x][j];
if(x==y) return x;
for(int j=29;j>=0;j--) if(fa[x][j]!=fa[y][j]) {x=fa[x][j];y=fa[y][j];}
return fa[x][0];
}
int k;
bool cmp(int x,int y) {return dfn[x]<dfn[y];}
map<int,bool> guan;
vector<int> b;
vector<int> A;
map<int,vector<int>> e;
int dp(int u)
{
if(guan[u]) return w[u];
int tmp=0;
for(auto v:e[u]) tmp+=dp(v);
return min(tmp,w[u]);
}
void solve()
{
b.clear(),A.clear();
guan.clear();
e.clear();
cin>>k;
for(int i=1;i<=k;i++) {int x;cin>>x;guan[x]=1;b.push_back(x);}
A.push_back(1);
for(auto v:b) A.push_back(v);
sort(b.begin(),b.end(),cmp);
for(int i=0;i<b.size()-1;i++) A.push_back(lca(b[i],b[i+1]));
sort(A.begin(),A.end(),cmp);
//cout<<'g';for(auto v:A) cout<<v<<' ';cout<<'g'<<'\n';
int tot=unique(A.begin(),A.end())-A.begin();
//cout<<tot<<'\n';
//cout<<'g';for(auto v:A) cout<<v<<' ';cout<<'g'<<'\n';
for(int i=0;i<tot-1;i++) e[lca(A[i],A[i+1])].push_back(A[i+1]);
//for(int i=1;i<=n;i++) for(auto v:e[i]) cout<<i<<' '<<v<<'\n';
cout<<dp(1)<<'\n';
}
signed main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
cin>>n;
for(int i=0;i<=n;i++) w[i]=INF;
a[0][1]=INF;
for(int i=1;i<n;i++)
{
int u,v,d;
cin>>u>>v>>d;
G[u].push_back(v);G[v].push_back(u);
a[u][v]=d,a[v][u]=d;
}
dfs(1,0);
//for(int i=1;i<=n;i++) cout<<fa[i][0]<<' ';cout<<'\n';
//for(int i=1;i<=n;i++) cout<<w[i]<<' ';cout<<'\n';
cin>>m;
while(m--) solve();
return 0;
}