[题解]1912异象石

题目描述

problem.png

题解

首先考虑如何求树上的任意个点连通的边集的总长度的最小值

可以对树上的点进行一次$dfs$求出它们的$dfs$序,将所求点按$dfs$序排序后相邻两点(包括头尾)的距离之和即为使这些点连通的边集的总长度的最小值的两倍

如图(蓝色为各点$dfs$序)

Inkedgraph _1__LI.jpg

假设我们要让$1,2,3,5$连通

按$dfs$序排序后得到$1,3,5,2$

$dis(1,3)+dis(3,5)+dis(5,2)+dis(2,1)=5+15+11+1$

观察发现每条边都被计算两次,所以结果是答案的$2$倍

有了以上结论后,我们可以得出一个算法,使用平衡树或$set$维护所有异象石的$dfs$序,每次插入一个新的异象石$k$时,查找到它在已有异象石中的前驱和后继,假设为$i,j$,更新$ans$,令其减去$dis(i,j)$并加上$dis(i,k)+dis(k,j)$。同理,删除操作则令$ans$加上$dis(i,j)$减去$dis(i,k)+dis(k,j)$

注意特判边界

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 1e5 + 10;

int read()
{
int sum = 0, f = 1;
char c = getchar();
while (c < '0' or c > '9')
{
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' and c <= '9')
{
sum = sum * 10 + c - '0';
c = getchar();
}
return sum * f;
}

int n, m, cnt, head[MAXN];
int dfn[MAXN];
long long dist[MAXN], tot, f[30][MAXN], depth[MAXN];
struct G
{
int To, Nxt, Dis;
} edge[MAXN << 1];

void add(int From, int To, int Dis)
{
edge[++cnt].Dis = Dis;
edge[cnt].Nxt = head[From];
edge[cnt].To = To;
head[From] = cnt;
}

void dfs(int k, int fa)
{
tot++;
dfn[k] = tot;
f[0][k] = fa;
depth[k] = depth[fa] + 1;
for (int i = head[k]; i; i = edge[i].Nxt)
{
int v = edge[i].To;
if (v == fa)
continue;
dist[v] = dist[k] + edge[i].Dis;
dfs(v, k);
}
}

void prework()
{
for (int j = 1; j <= 20; j++)
for (int i = 1; i <= n; i++)
f[j][i] = f[j - 1][f[j - 1][i]];
}

int Lca(int x, int y)
{
if (depth[x] < depth[y])
swap(x, y);
for (int j = 20; j >= 0; j--)
if (depth[f[j][x]] >= depth[y])
x = f[j][x];
if (x == y)
return x;
for (int j = 20; j >= 0; j--)
if (f[j][x] != f[j][y])
x = f[j][x], y = f[j][y];
return f[0][x];
}

set<int> s;
int id[MAXN];
char ch;
long long ans;

long long calc(int x, int y)
{
int t = Lca(x, y);
return dist[x] + dist[y] - 2 * dist[t];
}

int main()
{
n = read();
for (int i = 1; i < n; i++)
{
int x = read(), y = read(), z = read();
add(x, y, z), add(y, x, z);
}
dfs(1, 0);
for (int i = 1; i <= n; i++)
id[dfn[i]] = i;
prework();
m = read();
while (m--)
{
scanf(" ");
scanf("%c", &ch);
if (ch == '+')
{
int x = read();
if (s.empty())
{
s.insert(dfn[x]);
continue;
}
set<int>::iterator it = s.lower_bound(dfn[x]);
if (it == s.begin())
{
int y = *it;
set<int>::iterator it1 = --s.end();
int z = *it1;
ans += (calc(x, id[y]) + calc(x, id[z]) - calc(id[y], id[z]));
}
else
{
if (it == s.end())
{
it--;
int y = *it;
int z = *s.begin();
ans += (calc(x, id[y]) + calc(x, id[z]) - calc(id[y], id[z]));
}
else
{
int y = *it;
if (it != s.begin())
it--;
int z = *it;
ans += (calc(x, id[y]) + calc(x, id[z]) - calc(id[y], id[z]));
}
}
s.insert(dfn[x]);
}
if (ch == '-')
{
int x = read();
if (s.size() == 1)
{
s.erase(dfn[x]);
continue;
}
set<int>::iterator it = s.find(dfn[x]);
if (it == s.begin())
{
it++;
int y = *it;
set<int>::iterator it1 = --s.end();
int z = *it1;
ans -= (calc(x, id[y]) + calc(x, id[z]) - calc(id[y], id[z]));
}
else
{
set<int>::iterator End = s.end();
End--;
if (it == End)
{
if (it != s.begin())
it--;
int y = *it;
int z = *s.begin();
ans -= (calc(x, id[y]) + calc(x, id[z]) - calc(id[y], id[z]));
}
else
{
it++;
int y = *it;
it--, it--;
int z = *it;
ans -= (calc(x, id[y]) + calc(x, id[z]) - calc(id[y], id[z]));
}
}
s.erase(dfn[x]);
}
if (ch == '?')
{
printf("%lld\n", ans / 2);
}
}
return 0;
}