树的直径

求树的直径有两种方法,一种是树形dp,一种是两次dfs/bfs

性质

树的直径有一重要性质,在无负权边的图中,所有直径的中点相同。

树形dp

设 $d[x]$ 表示从 $x$ 开始到 $x$ 的子树中的最大距离,设 $f[x]$ 表示通过 $x$ 的最长链长度,那么有 $f[x]=max(d[y_i]+d[y_j]+w(x,y_i)+w(x,y_j))$ 。$f[x]$ 可以在dfs时直接维护。

1
2
3
4
5
6
7
8
9
10
11
12
void dp(int x)
{
v[x]=1;
for(int i=hed[x];i;i=nxt[i])
{
int y=to[i];
if(v[y])continue;
dfs(y);
ans=max(ans,d[x]+d[y]+w[i]);
d[x]=max(d[x],d[y]+w[i]);
}
}

两次dfs

从根开始dfs找到最远点,然后从最远点dfs找到的最长路径就是直径。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include<bits/stdc++.h>
using namespace std;
#define int long long

const int N=2e5+5;
const int mod=998244353;
int n;
int cnt,hed[N*2],nxt[N*2],to[N*2];
int d[N];

int dmx,vt,zj;

void add(int x,int y)
{
to[++cnt]=y;
nxt[cnt]=hed[x];
hed[x]=cnt;
}

void dfs(int u,int fa)
{
for(int i=hed[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa)continue;
d[v]=d[u]+1;

if(d[v]>dmx)
{
dmx=d[v];
vt=v;
}
dfs(v,u);

}
}

signed main()
{
ios_base::sync_with_stdio(false);
cin>>n;
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
add(x,y);
add(y,x);
}

dmx=0;
dfs(1,0);

dmx=0;memset(d,0,sizeof(d));
dfs(vt,0);

zj=dmx;

cout<<zj<<endl;

}


最近公共祖先

采用倍增法求最近公共祖先,用 $fa[x][k]$ 表示 $x$ 向上走 $2^k$ 步到达的祖先,那么有 $fa[x][k]=fa[fa[x][k-1]][k-1]$ 。在求 $lca(x,y)$ 时,首先将两点深度调整为一样,然后两个点同时往上跳 $2^k$ 步,相当于把两点与祖先的距离二进制拆分,每一步跳一位二进制1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include<bits/stdc++.h>
using namespace std;
//
int n,m,s,t;
int hed[1000005],nxt[1000005],to[1000005],cnt;
int d[500005];
int fa[500005][25];

void add(int u,int v)
{
to[++cnt]=v;
nxt[cnt]=hed[u];
hed[u]=cnt;
}

void bfs()
{
queue<int>q;
d[s]=1;
q.push(s);
while(q.size())
{
int u=q.front();
q.pop();
for(int i=hed[u];i;i=nxt[i])
{
int v=to[i];
if(d[v])continue;
d[v]=d[u]+1;
fa[v][0]=u;
for(int j=1;j<=t;j++)fa[v][j]=fa[fa[v][j-1]][j-1];
q.push(v);
}
}
}

int lca(int x,int y)
{
if(d[x]>d[y])swap(x,y);
for(int i=t;i>=0;i--)if(d[fa[y][i]]>=d[x])y=fa[y][i];
if(x==y)return x;
for(int i=t;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}

int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n>>m>>s;
t=(int)(log(n)/log(2))+1;
for(int i=1;i<n;i++)
{
int x,y;
cin>>x>>y;
add(x,y);
add(y,x);
}

bfs();

for(int i=1;i<=m;i++)
{
int x,y;
cin>>x>>y;
cout<<lca(x,y)<<endl;
}
return 0;
}