1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
#include <algorithm>
#include <cstdio>
#include <memory>
#include <vector>
// Compile-time capacities for the fixed-size arrays inside Solver:
// MaxN is the extent of every per-vertex array, MaxK the number of rows in
// the ancestor table f[][].
constexpr int MaxN = 300050;
constexpr int MaxK = 20;
struct Solver {
std::vector<int> g[MaxN]; int f[MaxK][MaxN], in[MaxN], out[MaxN], id[MaxN], tim; void dfs(int u) { id[in[u] = ++tim] = u; for (int v : g[u]) if (!in[v]) { f[0][v] = u; dfs(v); } out[u] = tim; }
struct Segment {int l,r;};
int total(const std::vector<Segment> &l) { int sum = 0; for (const Segment &l0 : l) sum += l0.r-l0.l+1; return sum; } int kth(int k, const std::vector<Segment> &l) { for (const Segment &l0 : l) { int len = l0.r-l0.l+1; if (k>len) k-=len; else return l0.l+k-1; } return -1; } int subTreeSize(int u, const std::vector<Segment> &l) { int siz = 0; for (const Segment &l0 : l) { int L = std::max(in[u], l0.l), R = std::min(out[u], l0.r); siz += std::max(R-L+1, 0); } return siz; }
int K; int getCenter(int u, int S, const std::vector<Segment> &l) { if (2*subTreeSize(u, l) >= S) return u; for (int k=K; k>=0; k--) if (f[k][u] && 2*subTreeSize(f[k][u], l) < S) u = f[k][u]; return f[0][u]; } int calc(const std::vector<Segment> &l) { int S = total(l), mid = kth(S/2+1, l), t = id[mid], cent = getCenter(t, S, l); if (2*subTreeSize(cent, l) < S) puts("!"); return (2*subTreeSize(cent, l) == S) ? cent + f[0][cent] : cent; }
void solve() { int N; scanf("%d", &N); while((1<<(K+1))<N) K++; for (int i=1; i<N; i++) { int u, v; scanf("%d%d", &u, &v); g[u].push_back(v); g[v].push_back(u); } dfs(1); for (int k=1; k<=K; k++) for (int u=1; u<=N; u++) f[k][u] = f[k-1][f[k-1][u]]; long long sum = 0; std::vector<Segment> l; for (int u=2; u<=N; u++) { l.clear(); l.push_back((Segment){in[u], out[u]}); sum += calc(l); l.clear(); l.push_back((Segment){1,in[u]-1}); l.push_back((Segment){out[u]+1,N}); sum += calc(l); } printf("%lld\n", sum); }
}solver;
int main() { int T; scanf("%d", &T); while(T--) { solver = Solver(); solver.solve(); } return 0; }
|