
DSU on Tree

DSU on tree: keep the heavy child's data, re-add light children

The Problem: Distinct Colors Per Subtree

Every node in a tree has a color. For each node, you want to know: how many distinct colors appear in its subtree?

Naive approach: for each node, DFS its subtree and count distinct colors. That's \(O(n)\) per node, \(O(n^2)\) total. With \(n = 10^5\), you need something better.

Small-to-large merging (also called DSU on tree) solves this in \(O(n \log n)\).


The Technique

The idea is deceptively simple. When you're at a node and you've processed all its children, you need to combine their color sets. The trick: always merge the smaller set into the larger one.

Why does this help? When an element moves from the smaller set into the larger one, the set containing it at least doubles in size. Since no set holds more than \(n\) elements, each element moves at most \(\log_2 n\) times, so the total number of moves across the entire algorithm is \(O(n \log n)\).
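To make the doubling argument concrete, here is a sketch of the literal set-merging version, assuming 0-indexed nodes, \(n \ge 1\), and a tree rooted at node 0 (the function name distinctBySetMerging is hypothetical). Each std::set insertion costs \(O(\log n)\), so this literal form runs in \(O(n \log^2 n)\).

```cpp
#include <cassert>
#include <set>
#include <vector>

// Literal small-to-large set merging for "distinct colors per subtree".
// Assumptions: nodes are 0-indexed, n >= 1, the tree is rooted at node 0.
std::vector<int> distinctBySetMerging(const std::vector<std::vector<int>>& adj,
                                      const std::vector<int>& color) {
    int n = static_cast<int>(adj.size());
    std::vector<int> parent(n, -1), order, todo = {0};
    std::vector<bool> seen(n, false);
    seen[0] = true;
    while (!todo.empty()) {  // iterative preorder: every parent precedes its children
        int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        for (int u : adj[v])
            if (!seen[u]) { seen[u] = true; parent[u] = v; todo.push_back(u); }
    }

    std::vector<std::set<int>> colors(n);  // colors[v] = colors collected for v's subtree
    std::vector<int> answer(n);
    for (int i = n - 1; i >= 0; --i) {  // reverse preorder: children before parents
        int v = order[i];
        colors[v].insert(color[v]);
        answer[v] = static_cast<int>(colors[v].size());
        if (parent[v] == -1) continue;
        std::set<int>& big = colors[parent[v]];
        std::set<int>& small = colors[v];
        if (small.size() > big.size()) big.swap(small);  // parent keeps the larger set
        big.insert(small.begin(), small.end());           // move the smaller one over
        small.clear();
    }
    return answer;
}
```

On the 7-node example traced later (relabelled 0..6, with red = 0, blue = 1, green = 2) this returns {3, 3, 2, 1, 1, 2, 1}.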

But we don't literally maintain sets. Instead, we use a global frequency array and a specific DFS order.


How It Works in Practice

For each node, designate its heavy child: the child with the largest subtree (the same choice as in heavy-light decomposition, HLD). Among the recursive calls, the heavy child is processed last.

The DFS does this:

  1. Recurse on all light children first. After processing each light child, undo its contribution to the global count (clear its subtree from the frequency array).
  2. Recurse on the heavy child. Do NOT undo its contribution — keep its data in the global count.
  3. Now add all light children's subtrees back into the global count (on top of the heavy child's data).
  4. Add the current node itself.
  5. Record the answer for this node.

The heavy child's data survives, so the largest child subtree is never re-processed at its parent; only the light subtrees are re-added. Each light child's subtree holds at most half of its parent's subtree, and that halving is what keeps the total re-processing across the tree at \(O(n \log n)\).

The key insight: keep the heavy child's data, re-add the light children. A node is re-added once for each light edge on its path to the root, and a root-to-leaf path has at most \(\log_2 n\) light edges, so each node is re-added \(O(\log n)\) times.


Trace: 7-Node Tree with Colors

           1 (red)
          /         \
     2 (red)         3 (blue)
    /     \          |
4 (blue)  5 (green)  6 (red)
                     |
                     7 (blue)

Colors: 1=red, 2=red, 3=blue, 4=blue, 5=green, 6=red, 7=blue.

Subtree sizes and heavy children:

Node  Subtree size  Child subtree sizes  Heavy child
1     7             2→3, 3→3             either (say 3)
2     3             4→1, 5→1             either (say 4)
3     3             6→2                  6
6     2             7→1                  7

For node 1 both children have size 3; either choice gives the same answers, and the trace below takes 3.

Processing order (light children first, then heavy):

At node 2: process light child 5, clear it, then process heavy child 4 (keep data), re-add 5, add node 2.

At node 3: process heavy child 6 (which first processes its own heavy child 7 and keeps that data), then add node 3.

At node 1: process light child 2, clear it, then process heavy child 3 (keep data), re-add 2's subtree, add node 1.

Step-by-step global frequency tracking:

Step                                        Action          Freq (R,B,G)  Distinct
Process node 5 (light child of 2)           add 5 (green)   (0,0,1)       1
Answer for node 5: 1
Clear node 5 (light child of 2)             remove 5        (0,0,0)       0
Process node 4 (heavy child of 2)           add 4 (blue)    (0,1,0)       1
Answer for node 4: 1 (keep data)
Re-add node 5 for node 2                    add 5 (green)   (0,1,1)       2
Add node 2                                  add 2 (red)     (1,1,1)       3
Answer for node 2: 3
Clear subtree of node 2 (light child of 1)  remove 2, 4, 5  (0,0,0)       0
Process node 7 (heavy child of 6)           add 7 (blue)    (0,1,0)       1
Answer for node 7: 1 (keep data)
Add node 6                                  add 6 (red)     (1,1,0)       2
Answer for node 6: 2 (keep data: 6 is the heavy child of 3)
Add node 3                                  add 3 (blue)    (1,2,0)       2
Answer for node 3: 2 (keep data: 3 is the heavy child of 1)
Re-add subtree of node 2                    add 2, 4, 5     (2,3,1)       3
Add node 1                                  add 1 (red)     (3,3,1)       3
Answer for node 1: 3

Final answers:

Node  Distinct colors in subtree
1     3 (red, blue, green)
2     3 (red, blue, green)
3     2 (blue, red)
4     1 (blue)
5     1 (green)
6     2 (red, blue)
7     1 (blue)

The Complexity Proof

Why is the total work \(O(n \log n)\)?

Consider any node \(v\). How many times does \(v\) get added to and removed from the global count across the entire algorithm?

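\(v\) is removed from the global count only when a call dfs(u, parent, false) clears the subtree of some node \(u\) on the path from \(v\) up to the root (possibly \(v\) itself), which happens exactly when \(u\) is a light child of its parent. Each such removal is followed by exactly one re-add, when \(u\)'s parent re-adds its light subtrees. So besides its first addition, \(v\) is added or removed at most twice per light edge on its path to the root. Going up across a light edge at least doubles the subtree size (a light child holds at most half of its parent's subtree, the same argument as in HLD), so a root-to-leaf path has at most \(\log_2 n\) light edges, and \(v\) is touched \(O(\log n)\) times in total.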

Total: \(n\) nodes \(\times\) \(O(\log n)\) operations each = \(O(n \log n)\).
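The light-edge bound can also be checked empirically. The sketch below (maxLightEdgesOnAnyPath is a hypothetical helper; it assumes 0-indexed nodes, \(n \ge 1\), root 0, and breaks ties the way computeSizes does, with a strict '>') counts the light edges on every root-to-node path and returns the maximum.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Maximum number of light edges on any root-to-node path.
// Assumptions: nodes 0-indexed, n >= 1, root 0; among equal-size children the
// first one in adjacency order becomes the heavy child (strict '>').
int maxLightEdgesOnAnyPath(const std::vector<std::vector<int>>& adj) {
    int n = static_cast<int>(adj.size());
    std::vector<int> parent(n, -1), order, todo = {0};
    std::vector<bool> seen(n, false);
    seen[0] = true;
    while (!todo.empty()) {  // iterative preorder
        int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        for (int u : adj[v])
            if (!seen[u]) { seen[u] = true; parent[u] = v; todo.push_back(u); }
    }

    std::vector<int> subtree(n, 1), heavy(n, -1);
    for (int i = n - 1; i > 0; --i)  // children before parents; order[0] is the root
        subtree[parent[order[i]]] += subtree[order[i]];
    for (int v = 0; v < n; ++v)
        for (int u : adj[v])
            if (u != parent[v] && (heavy[v] == -1 || subtree[u] > subtree[heavy[v]]))
                heavy[v] = u;

    std::vector<int> light(n, 0);  // light[v] = light edges between the root and v
    int best = 0;
    for (int v : order) {  // a parent's value is final before its children are visited
        if (parent[v] != -1)
            light[v] = light[parent[v]] + (heavy[parent[v]] == v ? 0 : 1);
        best = std::max(best, light[v]);
    }
    return best;
}
```

On a complete binary tree with 15 nodes it returns 3, matching \(\lfloor \log_2 15 \rfloor = 3\); on a path it returns 0, since every edge is heavy.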


The Code

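#include <vector>
using namespace std;

// Globals. Assumptions: nodes are numbered 1..n with n < MAXN, node 0 is unused
// (so 0 can serve as "no parent"), and colors lie in [1, MAXN).
const int MAXN = 100005;
vector<int> adj[MAXN];
int color[MAXN], subtreeSize[MAXN], heavyChild[MAXN], answer[MAXN];
int colorCount[MAXN];     // colorCount[c] = copies of color c in the global state
int currentDistinct = 0;  // number of colors c with colorCount[c] > 0

// Pass 1: subtree sizes and heavy child (largest subtree; on a tie the first
// child in adjacency order wins because of the strict '>').
void computeSizes(int node, int parentNode) {
    subtreeSize[node] = 1;
    heavyChild[node] = -1;
    int maxChildSize = 0;

    for (int neighbor : adj[node]) {
        if (neighbor == parentNode) continue;
        computeSizes(neighbor, node);
        subtreeSize[node] += subtreeSize[neighbor];
        if (subtreeSize[neighbor] > maxChildSize) {
            maxChildSize = subtreeSize[neighbor];
            heavyChild[node] = neighbor;
        }
    }
}

// Add (delta = +1) or remove (delta = -1) every node of node's subtree.
void addSubtree(int node, int parentNode, int delta) {
    colorCount[color[node]] += delta;
    if (delta == 1 && colorCount[color[node]] == 1) currentDistinct++;
    if (delta == -1 && colorCount[color[node]] == 0) currentDistinct--;

    for (int neighbor : adj[node]) {
        if (neighbor == parentNode) continue;
        addSubtree(neighbor, node, delta);
    }
}

// Pass 2: fills answer[]. On return the global state holds node's subtree if
// keepData is true, and is empty otherwise.
void dfs(int node, int parentNode, bool keepData) {
    // Light children first; each one clears its own data before returning.
    for (int neighbor : adj[node]) {
        if (neighbor == parentNode || neighbor == heavyChild[node]) continue;
        dfs(neighbor, node, false);
    }

    // Heavy child last; its data stays in the global state.
    if (heavyChild[node] != -1) {
        dfs(heavyChild[node], node, true);
    }

    // Re-add the light subtrees on top of the heavy child's data.
    for (int neighbor : adj[node]) {
        if (neighbor == parentNode || neighbor == heavyChild[node]) continue;
        addSubtree(neighbor, node, 1);
    }

    // Add the node itself and record the answer.
    colorCount[color[node]]++;
    if (colorCount[color[node]] == 1) currentDistinct++;

    answer[node] = currentDistinct;

    // A light child's data must not leak into its siblings: clear it.
    if (!keepData) {
        addSubtree(node, parentNode, -1);
    }
}

// Usage (root = 1): computeSizes(1, 0); dfs(1, 0, true);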

Complexity: \(O(n \log n)\) time, \(O(n)\) space.
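The recursive addSubtree above walks each light subtree through the adjacency lists. A common alternative layout, sketched here as a self-contained struct (DsuOnTree is a hypothetical name; nodes are 0-indexed, the root is 0, and colors are assumed to lie in [0, numColors)), records a DFS entry order in prepare so that every subtree occupies the contiguous range byTime[tin[v]] .. byTime[tout[v] - 1]. Adding or clearing a subtree then becomes a plain loop; the complexity is unchanged.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Euler-tour layout of DSU on tree (an alternative to the recursive addSubtree).
// Assumptions: nodes 0-indexed, root 0, colors in [0, numColors).
struct DsuOnTree {
    std::vector<std::vector<int>> adj;
    std::vector<int> color, sz, heavy, tin, tout, byTime, cnt, answer;
    int timer = 0, distinct = 0;

    DsuOnTree(std::vector<std::vector<int>> g, std::vector<int> c, int numColors)
        : adj(std::move(g)), color(std::move(c)) {
        int n = static_cast<int>(adj.size());
        sz.assign(n, 1); heavy.assign(n, -1); tin.assign(n, 0); tout.assign(n, 0);
        byTime.assign(n, 0); cnt.assign(numColors, 0); answer.assign(n, 0);
    }

    // Subtree sizes, heavy child, and entry times: subtree(v) = byTime[tin[v] .. tout[v]).
    void prepare(int v, int p) {
        tin[v] = timer;
        byTime[timer++] = v;
        for (int u : adj[v]) {
            if (u == p) continue;
            prepare(u, v);
            sz[v] += sz[u];
            if (heavy[v] == -1 || sz[u] > sz[heavy[v]]) heavy[v] = u;
        }
        tout[v] = timer;
    }

    void change(int v, int delta) {  // delta = +1 adds v's color, -1 removes it
        int& c = cnt[color[v]];
        if (delta == 1 && c == 0) ++distinct;
        c += delta;
        if (delta == -1 && c == 0) --distinct;
    }

    void dfs(int v, int p, bool keep) {
        for (int u : adj[v])  // light children first, each cleared on return
            if (u != p && u != heavy[v]) dfs(u, v, false);
        if (heavy[v] != -1) dfs(heavy[v], v, true);  // heavy child last, data kept
        for (int u : adj[v])  // re-add light subtrees as contiguous ranges
            if (u != p && u != heavy[v])
                for (int t = tin[u]; t < tout[u]; ++t) change(byTime[t], +1);
        change(v, +1);
        answer[v] = distinct;
        if (!keep)  // clear the whole subtree, v included
            for (int t = tin[v]; t < tout[v]; ++t) change(byTime[t], -1);
    }

    std::vector<int> run() {
        prepare(0, -1);
        dfs(0, -1, true);
        return answer;
    }
};
```

On the Try It tree below, relabelled to nodes 0..4, run() returns {3, 1, 3, 1, 1}.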


Why "DSU on Tree"?

The name comes from the merge pattern. In union-find (DSU), the union-by-size heuristic attaches the smaller tree to the larger one, the same small-into-large rule used here. The tree structure just determines which sets to merge.

You'll also see this technique called small-to-large merging or the heavy-light trick for subtree queries. All the same idea.
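For comparison, a minimal union-find with union by size might look like this (the UnionFind struct is illustrative, not from any library; path compression is left out to keep the size rule visible). A node's depth grows only when its tree is attached under a root whose tree is at least as large, so its tree at least doubles each time, which is the same counting argument as for DSU on tree.

```cpp
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

// Union-find with union by size only (no path compression, for clarity).
struct UnionFind {
    std::vector<int> parent, size;  // size[r] is meaningful only for roots r

    explicit UnionFind(int n) : parent(n), size(n, 1) {
        std::iota(parent.begin(), parent.end(), 0);  // every node starts as its own root
    }

    int find(int x) {
        while (parent[x] != x) x = parent[x];
        return x;
    }

    bool unite(int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) return false;
        if (size[a] < size[b]) std::swap(a, b);  // a is now the larger root
        parent[b] = a;                           // smaller tree goes under the larger
        size[a] += size[b];
        return true;
    }
};
```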


Try It

Input: 5 nodes, colors [1, 2, 1, 3, 2]
Edges: (1,2),(1,3),(3,4),(3,5)
Output: answers [3, 1, 3, 1, 1]

Predict before running: If every node has the same color, what is the answer for every node? It's 1 everywhere — distinct count doesn't grow.
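A brute-force reference is handy for checking answers like these. The sketch below is the naive approach from the start of the article (roughly quadratic, plus a log factor for std::set), assuming 0-indexed nodes rooted at node 0, so the Try It labels shift down by one; distinctBruteForce and collectColors are hypothetical names.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Collects the colors of v's subtree, where p is v's parent (-1 for the root).
void collectColors(const std::vector<std::vector<int>>& adj, const std::vector<int>& color,
                   int v, int p, std::set<int>& out) {
    out.insert(color[v]);
    for (int u : adj[v])
        if (u != p) collectColors(adj, color, u, v, out);
}

// Naive reference: one full subtree walk per node.
// Assumptions: nodes are 0-indexed, n >= 1, the tree is rooted at node 0.
std::vector<int> distinctBruteForce(const std::vector<std::vector<int>>& adj,
                                    const std::vector<int>& color) {
    int n = static_cast<int>(adj.size());
    std::vector<int> parent(n, -1), todo = {0};
    std::vector<bool> seen(n, false);
    seen[0] = true;
    while (!todo.empty()) {  // find every node's parent
        int v = todo.back();
        todo.pop_back();
        for (int u : adj[v])
            if (!seen[u]) { seen[u] = true; parent[u] = v; todo.push_back(u); }
    }
    std::vector<int> answer(n);
    for (int v = 0; v < n; ++v) {
        std::set<int> colors;
        collectColors(adj, color, v, parent[v], colors);
        answer[v] = static_cast<int>(colors.size());
    }
    return answer;
}
```

With colors {1, 2, 1, 3, 2} it returns {3, 1, 3, 1, 1}, and with every color equal it returns 1 for every node.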

Challenge: Modify the problem to find, for each node, the most frequent color in its subtree. The DSU-on-tree framework is the same — just change what you track in the global state.
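One possible shape for that global state tracks the highest frequency and the colors that reach it. This sketch assumes ties are reported as the sum of the tied colors (the convention CF 600E uses) and that colors lie in [0, numColors); DominantColorState is a hypothetical name. Removal simply resets the summary. That is only valid because, in DSU on tree, clearing a light subtree always empties the global state completely.

```cpp
#include <cassert>
#include <vector>

// Hypothetical global state for "most frequent color in each subtree".
// Assumption: ties are reported as the sum of all tied colors (CF 600E style).
struct DominantColorState {
    std::vector<int> count;  // count[c] = copies of color c currently added
    int maxFreq = 0;         // highest count among the added colors
    long long sumOfMax = 0;  // sum of the colors whose count equals maxFreq

    explicit DominantColorState(int numColors) : count(numColors, 0) {}

    void add(int c) {  // O(1): counts only grow while a subtree is being built
        int f = ++count[c];
        if (f > maxFreq) { maxFreq = f; sumOfMax = c; }
        else if (f == maxFreq) sumOfMax += c;
    }

    // Only correct for the DSU-on-tree clearing pattern: a light subtree is
    // always removed completely, so the state ends up empty and the summary
    // can be reset instead of recomputed.
    void remove(int c) {
        --count[c];
        maxFreq = 0;
        sumOfMax = 0;
    }
};
```

In dfs, the colorCount/currentDistinct updates would become state.add(color[v]) or state.remove(color[v]), and the recorded answer would be state.sumOfMax.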

Application: Interaction Costs in Groups (LC 3786)

The distinct-colors problem counts values. This variant sums distances — and the small-to-large merge is the key to making it efficient.

Problem: Every node belongs to a group. For each pair of same-group nodes, the interaction cost is the path length between them. Return the total cost across all same-group pairs.

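The edge contribution insight: Let \(T_g\) be the total number of nodes in group \(g\). For any edge separating a subtree from the rest of the tree, if \(c\) members of group \(g\) lie below the edge, the other \(T_g - c\) lie above it, so exactly \(c \times (T_g - c)\) same-group paths cross this edge. A path of length \(d\) crosses exactly \(d\) edges, so summing \(c \times (T_g - c)\) over all edges and all groups gives the total cost.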
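The identity is easy to sanity-check on small trees. The sketch below (costByAllPairs and costByEdges are hypothetical names; nodes 0-indexed, root 0) computes the total once by BFS over all same-group pairs and once by summing \(c \times (T_g - c)\) over edges. Both are slow reference implementations for small inputs, not the efficient solution.

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <vector>

// Reference 1: BFS from every node, sum distances of same-group pairs s < t.
long long costByAllPairs(const std::vector<std::vector<int>>& adj, const std::vector<int>& group) {
    int n = static_cast<int>(adj.size());
    long long total = 0;
    for (int s = 0; s < n; ++s) {
        std::vector<int> dist(n, -1);
        std::queue<int> q;
        dist[s] = 0;
        q.push(s);
        while (!q.empty()) {
            int v = q.front();
            q.pop();
            for (int u : adj[v])
                if (dist[u] == -1) { dist[u] = dist[v] + 1; q.push(u); }
        }
        for (int t = s + 1; t < n; ++t)
            if (group[t] == group[s]) total += dist[t];
    }
    return total;
}

// Reference 2: for every edge (parent, v), add c * (T_g - c) for each group g,
// where c = members of g in v's subtree.
long long costByEdges(const std::vector<std::vector<int>>& adj, const std::vector<int>& group) {
    int n = static_cast<int>(adj.size());
    std::map<int, long long> total;  // total[g] = T_g
    for (int g : group) ++total[g];
    std::vector<int> parent(n, -1), order, todo = {0};
    std::vector<bool> seen(n, false);
    seen[0] = true;
    while (!todo.empty()) {  // iterative preorder
        int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        for (int u : adj[v])
            if (!seen[u]) { seen[u] = true; parent[u] = v; todo.push_back(u); }
    }
    std::vector<std::map<int, long long>> below(n);  // group counts per subtree
    long long cost = 0;
    for (int i = n - 1; i >= 0; --i) {  // children before parents
        int v = order[i];
        below[v][group[v]] += 1;
        if (parent[v] == -1) continue;
        for (auto& [g, c] : below[v]) {
            cost += c * (total[g] - c);  // same-group paths crossing edge (parent[v], v)
            below[parent[v]][g] += c;
        }
    }
    return cost;
}
```

On the tree with edges (0,1), (0,2), (1,3), (1,4) and groups {0, 1, 0, 1, 0}, both functions return 7.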

The implementation: DFS bottom-up. Each node maintains a frequency counter of groups in its subtree. When merging a child's counter into the parent, compute contributions BEFORE merging:

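#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;
// Assumed globals: adj (adjacency lists), group[v], totalGroupCount[g] = T_g.
vector<vector<int>> adj;
vector<int> group;
unordered_map<int, long long> totalGroupCount;
long long totalCost = 0;
struct Counter {
    unordered_map<int, long long> freq;  // group -> members in this subtree
    long long crossing = 0;              // sum over groups of c * (T_g - c)
};
void addCount(Counter& ctr, int g, long long delta) {  // keeps `crossing` in sync, O(1)
    long long& c = ctr.freq[g];
    long long T = totalGroupCount[g];
    ctr.crossing += (c + delta) * (T - c - delta) - c * (T - c);
    c += delta;
}

Counter dfs(int node, int parent) {
    Counter nodeCounter;
    addCount(nodeCounter, group[node], 1);
    for (int child : adj[node]) {
        if (child == parent) continue;
        Counter childCounter = dfs(child, node);
        // Edge (node, child): contribution read BEFORE merging, in O(1)
        // (scanning the whole child map here would be O(n^2) on a path).
        totalCost += childCounter.crossing;
        if (childCounter.freq.size() > nodeCounter.freq.size())
            swap(nodeCounter, childCounter);  // small-to-large
        for (auto& [groupId, count] : childCounter.freq)
            addCount(nodeCounter, groupId, count);
    }
    return nodeCounter;
}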

The swap before merge is the small-to-large trick — always merge the smaller map into the larger one. Each element is merged \(O(\log n)\) times across the entire tree, giving \(O(n \log n)\) total.

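Why read the contribution BEFORE merging: if the child's subtree holds \(c\) members of group \(g\), the other \(T_g - c\) members lie on the far side of the edge (node, child). Each of the \(c \times (T_g - c)\) paths joining a member below the edge to one above it crosses the edge exactly once, and paths with both endpoints on the same side do not cross it. Once the child's counter is merged into the parent's, the two sides' counts are combined and this split can no longer be read off.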


Edge Cases to Watch For

  • Forgetting to clear the global state after processing a light child: After recursing into a light child's subtree, you must undo all its contributions to the global frequency array before processing the next child. Failing to clear creates cross-contamination between subtrees.
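  • Heavy child must be the last child recursed into: the algorithm relies on the heavy child's data persisting in the global state. If the heavy child were processed first and its data kept, every light child processed after it would compute its answers on top of the heavy child's colors. The code above avoids this by skipping heavyChild[node] in the first loop and recursing into it separately, so adjacency-list order does not matter.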

Problems

Problem                               Link                                                                Difficulty
LC 3786: Interaction Costs in Groups  leetcode.com/problems/total-sum-of-interaction-cost-in-tree-groups  Hard
CF 600E: Lomsat gelral                codeforces.com/contest/600/problem/E                                Medium
CF 570D: Tree Requests                codeforces.com/contest/570/problem/D                                Medium
CF 208E: Blood Cousins                codeforces.com/contest/208/problem/E                                Hard