
Centroid Decomposition

Centroid decomposition: removing the centroid splits the tree into balanced components

The Problem: Counting Paths of a Given Length

You have a tree with \(n\) nodes and edge weights. Count all pairs of nodes \((u, v)\) whose path has total weight exactly \(K\).

Brute force: enumerate all \(\binom{n}{2}\) pairs, compute each path length. That's \(O(n^2)\) at minimum. For \(n = 10^5\), too slow.
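A brute-force counter is still worth having as a reference for checking faster solutions on small inputs. The following is a minimal sketch, not code from this article: the names `WeightedEdge` and `countPairsBruteForce` and the 1-indexed edge-list input are assumptions made for illustration. It runs one DFS per source node, so it takes \(O(n^2)\) time overall.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical input type for this sketch: an undirected weighted edge.
struct WeightedEdge {
    int u;
    int v;
    long long w;
};

// Hypothetical helper: counts unordered pairs {u, v} with u < v whose path
// weight is exactly k. Nodes are numbered 1..n.
long long countPairsBruteForce(int n, const std::vector<WeightedEdge>& edges, long long k) {
    assert(n >= 1);
    std::vector<std::vector<std::pair<int, long long>>> adj(n + 1);
    for (const WeightedEdge& e : edges) {
        adj[e.u].push_back({e.v, e.w});
        adj[e.v].push_back({e.u, e.w});
    }

    long long pairs = 0;
    std::vector<long long> dist(n + 1, 0);
    for (int source = 1; source <= n; ++source) {
        // Iterative DFS from `source`; in a tree every node is reached once.
        std::vector<std::pair<int, int>> stack = {{source, 0}};  // (node, parent)
        dist[source] = 0;
        while (!stack.empty()) {
            auto [node, parent] = stack.back();
            stack.pop_back();
            // Requiring node > source counts each unordered pair exactly once.
            if (node > source && dist[node] == k) ++pairs;
            for (const auto& [next, w] : adj[node]) {
                if (next == parent) continue;
                dist[next] = dist[node] + w;
                stack.push_back({next, node});
            }
        }
    }
    return pairs;
}
```

Comparing its output with the fast solution on random small trees is a cheap way to catch counting mistakes.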

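Centroid decomposition solves this in \(O(n \log^2 n)\) with the sorting-based code in this article, or in expected \(O(n \log n)\) with hash-based counting.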


What Is a Centroid?

The centroid of a tree is a node whose removal splits the tree into connected components, each of size at most \(n/2\).

Every tree has at least one centroid (and at most two). Finding it is straightforward:

  1. Root the tree anywhere. Compute subtree sizes.
  2. Walk from the root toward the child with the largest subtree until no child's subtree exceeds \(n/2\).
  3. That node is the centroid.
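The steps above can be sketched as follows. This version is a variation rather than the exact walk: it tests every node against the definition and returns all centroids, which makes the "at most two" claim easy to observe. The function name `findCentroids` and the 0-indexed unweighted adjacency list are assumptions made for this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: returns every centroid of a tree given as a 0-indexed
// adjacency list. A node is a centroid when each component left after its
// removal has at most n/2 nodes.
std::vector<int> findCentroids(const std::vector<std::vector<int>>& adj) {
    const int n = static_cast<int>(adj.size());
    assert(n >= 1);
    std::vector<int> parent(n, -1), order, subtreeSize(n, 1);

    // Iterative DFS from node 0 records a preorder: parents before children.
    std::vector<int> stack = {0};
    while (!stack.empty()) {
        int node = stack.back();
        stack.pop_back();
        order.push_back(node);
        for (int next : adj[node]) {
            if (next == parent[node]) continue;
            parent[next] = node;
            stack.push_back(next);
        }
    }
    // Accumulate subtree sizes bottom-up by walking the preorder in reverse.
    for (int i = n - 1; i > 0; --i) {
        subtreeSize[parent[order[i]]] += subtreeSize[order[i]];
    }

    std::vector<int> centroids;
    for (int node = 0; node < n; ++node) {
        int largest = n - subtreeSize[node];  // the component containing the parent
        for (int next : adj[node]) {
            if (next != parent[node]) largest = std::max(largest, subtreeSize[next]);
        }
        // Comparing 2 * largest with n avoids rounding n/2 for odd n.
        if (2 * largest <= n) centroids.push_back(node);
    }
    return centroids;
}
```

On a path with an even number of nodes, the two middle nodes are both returned; on a path with an odd number of nodes, only the middle one is.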

The centroid is the "balance point" of the tree — no direction from it contains more than half the nodes.


The Decomposition

Build a centroid tree (a new tree, different from the original):

  1. Find the centroid \(c\) of the whole tree.
  2. Remove \(c\). The tree splits into several components.
  3. Recursively find the centroid of each component.
  4. Make \(c\) the parent of each component's centroid in the centroid tree.

Depth of the centroid tree: \(O(\log n)\). Each component handed to the recursion has at most half the nodes of the component it was cut from, so the component size at least halves with each level and the recursion bottoms out after at most \(\log_2 n\) levels.
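The depth bound can be written out as a short derivation. Here \(s_d\) denotes the largest component size at depth \(d\) of the recursion, with the whole tree at depth 0 (a symbol introduced only for this sketch):

\[
s_0 = n, \qquad s_{d+1} \le \frac{s_d}{2} \quad\Longrightarrow\quad s_d \le \frac{n}{2^d}.
\]

A component at depth \(d\) contains at least one node, so \(1 \le n / 2^d\), which gives \(d \le \log_2 n\).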

Every path in the original tree is handled at exactly one centroid. Specifically, the path from \(u\) to \(v\) passes through the LCA of \(u\) and \(v\) in the centroid tree, and it stays inside the component that this LCA was chosen as the centroid of.

This property is what makes centroid decomposition powerful: to process all paths, you only need to process paths passing through each centroid, and there are \(O(\log n)\) levels.


Trace: Decomposing a 7-Node Tree

    1
   / \
  2   3
 / \   \
4   5   6
        |
        7

All edges have weight 1. We want to count pairs with distance exactly 3.

Step 1: Find centroid of the whole tree (7 nodes).

Subtree sizes rooted at 1: node 1 has size 7, node 2 has size 3, node 3 has size 3.

Check node 1: largest component after removal is max(3, 3) = 3 ≤ 7/2 = 3.5. Node 1 is the centroid.

Step 2: Remove node 1. Components: {2, 4, 5} and {3, 6, 7}.

Find centroid of {2, 4, 5}: node 2 (removal gives components of size 1, 1). Centroid = 2.

Find centroid of {3, 6, 7}: this component is the path 3 - 6 - 7, so it has 3 nodes and \(n/2 = 1.5\). Removing 3 leaves {6, 7} of size 2 > 1.5, so 3 is not a centroid. Removing 6 leaves {3} and {7}, both of size 1 ≤ 1.5. Centroid = 6.

Centroid tree:

        1
      /   \
     2     6
    / \   / \
   4   5 3   7
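The construction can be checked mechanically by computing each node's parent in the centroid tree. The following is a minimal sketch under stated assumptions: `CentroidTreeBuilder` and `buildCentroidParents` are illustrative names, nodes are numbered 1..n with index 0 of the adjacency list unused, and a parent of 0 marks the root of the centroid tree.

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch: computes the parent of every node in the centroid tree.
struct CentroidTreeBuilder {
    const std::vector<std::vector<int>>& adj;  // 1-indexed, adj[0] unused
    std::vector<int> subtreeSize, centroidParent;
    std::vector<bool> removed;

    explicit CentroidTreeBuilder(const std::vector<std::vector<int>>& graph)
        : adj(graph), subtreeSize(graph.size(), 0), centroidParent(graph.size(), 0),
          removed(graph.size(), false) {}

    // Subtree sizes inside the current component; removed nodes act as walls.
    int computeSize(int node, int parent) {
        subtreeSize[node] = 1;
        for (int next : adj[node]) {
            if (next != parent && !removed[next]) subtreeSize[node] += computeSize(next, node);
        }
        return subtreeSize[node];
    }

    // Walk toward any child holding more than half of the component.
    int findCentroid(int node, int parent, int total) {
        for (int next : adj[node]) {
            if (next != parent && !removed[next] && 2 * subtreeSize[next] > total) {
                return findCentroid(next, node, total);
            }
        }
        return node;
    }

    void build(int start, int parentCentroid) {
        int total = computeSize(start, -1);
        int centroid = findCentroid(start, -1, total);
        centroidParent[centroid] = parentCentroid;
        removed[centroid] = true;
        for (int next : adj[centroid]) {
            if (!removed[next]) build(next, centroid);
        }
    }
};

std::vector<int> buildCentroidParents(const std::vector<std::vector<int>>& adj, int root) {
    assert(root >= 1 && root < static_cast<int>(adj.size()));
    CentroidTreeBuilder builder(adj);
    builder.build(root, 0);
    return builder.centroidParent;
}
```

For the 7-node tree in this trace, starting from node 1, the sketch reports node 1 as the root, nodes 2 and 6 as children of 1, nodes 4 and 5 as children of 2, and nodes 3 and 7 as children of 6.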

Step 3: Count paths of length 3 through each centroid.

At centroid 1: collect distances from 1 to every node in the original tree.

  • Component {2, 4, 5}: distances 1, 2, 2.
  • Component {3, 6, 7}: distances 1, 2, 3.
  • Centroid 1 itself: distance 0.

A pair (u, v) is a candidate when dist(u, 1) + dist(1, v) = 3. Over all seven distances [0, 1, 1, 2, 2, 2, 3] there are 7 candidates: (1, 7) from 0 + 3, and 2 × 3 = 6 pairs from 1 + 2.

Not every candidate is real. If u and v lie in the same component, their actual path never reaches centroid 1, so the sum overstates their distance. These fake pairs must be subtracted:

  • Within {2, 4, 5}: (2, 4) and (2, 5), whose true distance is 1. That is 2 fake pairs.
  • Within {3, 6, 7}: (3, 6), whose true distance is 1. That is 1 fake pair.

That leaves 7 − 2 − 1 = 4 paths of length 3 through centroid 1:

  • (1, 7): 0 + 3 = 3.
  • (2, 6): 1 + 2 = 3.
  • (3, 4): 1 + 2 = 3.
  • (3, 5): 1 + 2 = 3.

Continue at deeper centroids to count the remaining paths. The total over all centroids gives the answer. In this example the components at centroids 2 and 6 contain no path longer than 2, so they contribute nothing and the answer is 4.


The Code

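#include <algorithm>
#include <utility>
#include <vector>
using namespace std;

const int MAXN = 100005;              // assumed upper bound on n; adjust to the problem
vector<pair<int, int>> adj[MAXN];     // adj[u] = list of (neighbor, weight)
int subtreeSize[MAXN];
bool removed[MAXN];                   // true once a node has been used as a centroid
long long totalPairsWithDistK = 0;    // the answer; it can exceed the int range

// Computes subtree sizes inside the current component (removed nodes act as walls).
void computeSize(int node, int parentNode) {
    subtreeSize[node] = 1;
    for (auto [neighbor, weight] : adj[node]) {
        if (neighbor == parentNode || removed[neighbor]) continue;
        computeSize(neighbor, node);
        subtreeSize[node] += subtreeSize[neighbor];
    }
}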

int findCentroid(int node, int parentNode, int treeSize) {
    for (auto [neighbor, weight] : adj[node]) {
        if (neighbor == parentNode || removed[neighbor]) continue;
        if (subtreeSize[neighbor] > treeSize / 2) {
            return findCentroid(neighbor, node, treeSize);
        }
    }
    return node;
}

void collectDistances(int node, int parentNode, int currentDist,
                      vector<int>& distances) {
    distances.push_back(currentDist);
    for (auto [neighbor, weight] : adj[node]) {
        if (neighbor == parentNode || removed[neighbor]) continue;
        collectDistances(neighbor, node, currentDist + weight, distances);
    }
}

// Counts unordered index pairs (i < j) with distances[i] + distances[j] == targetSum.
// Sorts the vector in place, then sweeps it with two pointers.
long long countPairsWithSum(vector<int>& distances, int targetSum) {
    sort(distances.begin(), distances.end());
    int left = 0, right = (int)distances.size() - 1;
    long long pairCount = 0;  // up to n(n-1)/2, which overflows int for n = 10^5

    while (left < right) {
        int currentSum = distances[left] + distances[right];
        if (currentSum == targetSum) {
            int leftVal = distances[left], rightVal = distances[right];
            if (leftVal == rightVal) {
                // Every element in [left, right] is equal: choose any two of them.
                long long count = right - left + 1;
                pairCount += count * (count - 1) / 2;
                break;
            }
            // Consume the runs of equal values at both ends and pair them up.
            long long leftCount = 0, rightCount = 0;
            while (left <= right && distances[left] == leftVal) { left++; leftCount++; }
            while (right >= left && distances[right] == rightVal) { right--; rightCount++; }
            pairCount += leftCount * rightCount;
        } else if (currentSum < targetSum) {
            left++;
        } else {
            right--;
        }
    }
    return pairCount;
}

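void solve(int node, int targetDist) {
    // Sizes must be recomputed inside the current component before each search.
    computeSize(node, -1);
    int centroid = findCentroid(node, -1, subtreeSize[node]);
    removed[centroid] = true;

    // Distances from the centroid to every node of the component;
    // the centroid itself sits at distance 0.
    vector<int> allDistances;
    allDistances.push_back(0);

    for (auto [neighbor, weight] : adj[centroid]) {
        if (removed[neighbor]) continue;

        vector<int> componentDistances;
        collectDistances(neighbor, centroid, weight, componentDistances);

        // Pairs inside one component do not pass through the centroid: subtract them.
        totalPairsWithDistK -= countPairsWithSum(componentDistances, targetDist);

        for (int dist : componentDistances) {
            allDistances.push_back(dist);
        }
    }

    // All pairs at this centroid, including the same-component ones subtracted above.
    totalPairsWithDistK += countPairsWithSum(allDistances, targetDist);

    // Recurse into the components left after removing the centroid.
    for (auto [neighbor, weight] : adj[centroid]) {
        if (removed[neighbor]) continue;
        solve(neighbor, targetDist);
    }
}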

Complexity: \(O(n \log^2 n)\). Each node appears in one component per level of the centroid tree, and there are \(O(\log n)\) levels. Collecting and sorting the distances costs \(O(n \log n)\) in total across each level. Replacing the sort with a hash map of distance frequencies brings the total down to expected \(O(n \log n)\).
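The hash-map variant can be sketched as follows. This is an illustration under stated assumptions, not the article's code: `PathCounter` and its members are hypothetical names, nodes are numbered 1..n, and edge weights are assumed non-negative. Instead of subtracting same-component pairs, it queries the map before inserting a component's distances, so the two endpoints of a counted pair always come from different components (or one of them is the centroid).

```cpp
#include <cassert>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical sketch of centroid decomposition with hash-map counting.
struct PathCounter {
    std::vector<std::vector<std::pair<int, long long>>> adj;  // 1-indexed
    std::vector<int> subtreeSize;
    std::vector<bool> removed;
    long long target;
    long long total = 0;  // number of unordered pairs at distance `target`

    PathCounter(int n, long long k)
        : adj(n + 1), subtreeSize(n + 1, 0), removed(n + 1, false), target(k) {
        assert(n >= 1);
    }

    void addEdge(int u, int v, long long w) {
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }

    int computeSize(int node, int parent) {
        subtreeSize[node] = 1;
        for (const auto& [next, w] : adj[node]) {
            if (next != parent && !removed[next]) subtreeSize[node] += computeSize(next, node);
        }
        return subtreeSize[node];
    }

    int findCentroid(int node, int parent, int size) {
        for (const auto& [next, w] : adj[node]) {
            if (next != parent && !removed[next] && 2 * subtreeSize[next] > size) {
                return findCentroid(next, node, size);
            }
        }
        return node;
    }

    void collect(int node, int parent, long long dist, std::vector<long long>& out) {
        out.push_back(dist);
        for (const auto& [next, w] : adj[node]) {
            if (next != parent && !removed[next]) collect(next, node, dist + w, out);
        }
    }

    void solve(int start) {
        int size = computeSize(start, -1);
        int centroid = findCentroid(start, -1, size);
        removed[centroid] = true;

        // Frequencies of distances seen so far; the centroid itself is at 0.
        std::unordered_map<long long, long long> seen{{0, 1}};
        for (const auto& [next, w] : adj[centroid]) {
            if (removed[next]) continue;
            std::vector<long long> dists;
            collect(next, centroid, w, dists);
            // Query first, insert afterwards: no pair shares a component.
            for (long long d : dists) {
                auto it = seen.find(target - d);
                if (it != seen.end()) total += it->second;
            }
            for (long long d : dists) ++seen[d];
        }
        for (const auto& [next, w] : adj[centroid]) {
            if (!removed[next]) solve(next);
        }
    }
};
```

To use it, construct `PathCounter counter(n, K)`, add the edges with `addEdge`, call `counter.solve(1)`, and read `counter.total`. The trade-off is the usual one for hashing: the bound is expected rather than worst case.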


The Subtraction Trick

Notice the inclusion-exclusion in the code. We count pairs from all distances to the centroid (including pairs within the same component), then subtract pairs that are entirely within one component.

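Without this subtraction, you'd count pairs whose actual path never reaches the centroid. For such a pair, the sum of the two distances to the centroid walks up to the centroid and back down, so it overstates the true distance and the pair would be counted at the wrong length. This is the standard technique for centroid decomposition counting problems.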
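The subtraction can be stated as an identity. Let \(c\) be the centroid, \(C_1, \dots, C_m\) the components left after removing \(c\), and \(P(S)\) the number of unordered pairs \(\{u, v\} \subseteq S\) with \(d(u, c) + d(c, v) = K\) (notation introduced here for illustration only):

\[
\#\{\text{paths of weight } K \text{ through } c\} \;=\; P\big(\{c\} \cup C_1 \cup \dots \cup C_m\big) \;-\; \sum_{i=1}^{m} P(C_i).
\]

As a small check, take the path graph 1-2-3-4-5 with unit weights, centroid 3, and \(K = 3\). The distances to the centroid are 0, 1, 1, 2, 2, so \(P\) over all nodes is \(2 \times 2 = 4\). Each side component has distances 1 and 2, so \(P(\{1, 2\}) = P(\{4, 5\}) = 1\). The identity gives \(4 - 1 - 1 = 2\), matching the two real paths (1, 4) and (2, 5).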


Try It

Input: 7 nodes, all edge weights 1
Edges: (1,2),(1,3),(2,4),(2,5),(3,6),(6,7)
K = 3
Output: 4 (pairs: (1,7), (2,6), (3,4), (3,5))

Predict before running: On a path graph 1-2-3-4-5 with unit weights and \(K = 2\), the answer is 3 (pairs (1,3), (2,4), (3,5)). The centroid is node 3.

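Challenge: Modify the code to find the closest pair of nodes (minimum distance between any two distinct nodes). Hint: at each centroid, the best path through it joins the two smallest distances that come from different components, or the single smallest distance and the centroid itself.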

Edge Cases to Watch For

  • \(K = 0\) in distance counting: Every node forms a pair with itself if self-pairs are counted, or no pairs if only distinct pairs are counted. The problem statement determines which, and your counting logic must match — off-by-\(n\) errors come from this ambiguity.
  • Forgetting to re-compute subtree sizes after centroid removal: After removing the centroid and recursing into each component, the subtree sizes from the original tree are stale. You must recompute sizes within each component, or the centroid-finding step picks the wrong node and breaks the \(O(\log n)\) depth guarantee.

Problems

Problem                          Link                                    Difficulty
POJ 1741: Tree                   poj.org/problem?id=1741                 Hard
CF 342E: Xenia and Tree          codeforces.com/contest/342/problem/E    Hard
CF 321C: Ciel the Commander      codeforces.com/contest/321/problem/C    Hard