DP on Trees Explained: Rerooting, Subtree States & Hard Problems
Tree DP shows up everywhere serious algorithmic work happens — network routing optimization, compiler AST analysis, file-system quota calculations, org-chart scheduling, and every hard competitive programming problem that involves hierarchical data. The moment your input has a recursive, parent-child shape and you need an optimal answer over the whole structure, tree DP is the tool that turns an exponential brute-force into a clean O(n) pass. Ignoring it means either writing slow code or missing the solution entirely.
The core problem tree DP solves is this: trees have no obvious 'left-to-right' order the way arrays do, so classic 1D DP doesn't translate directly. Instead, every node depends on what its children can contribute, and children depend on their own children. The trick is defining a state that captures everything a parent needs to know about a subtree — nothing more, nothing less — then writing a transition that combines children's states cleanly. Get that state definition right and the rest flows naturally.
By the end of this article you'll be able to define correct DP states on trees, implement post-order (bottom-up) tree DP for problems like maximum independent set and tree diameter, apply the rerooting technique to answer 'what if every node were the root?' in a single extra pass, spot the common state-definition mistakes that cause wrong answers, and walk into an interview and explain the O(n) rerooting technique confidently when the naive approach is O(n²).
Post-Order Tree DP: The Foundation — Subtree States and How to Combine Them
Every tree DP problem starts with the same question: 'What does a parent need to know about a subtree?' That answer becomes your DP state. Once you've nailed the state, you write a DFS, process children before the current node (post-order), and merge child results into the parent's state.
The classic warm-up is the Maximum Weight Independent Set on a tree: choose a subset of nodes with maximum total weight such that no two chosen nodes share an edge. This is NP-hard on general graphs but O(n) on trees because the tree structure lets you make decisions subtree by subtree.
For each node you track exactly two values: the best weight achievable in that node's subtree when the node IS included, and the best weight when it IS NOT included. If a node is included, none of its children can be. If it's excluded, each child can independently choose to be included or excluded — whichever is better. That two-value state is the entire insight.
The DFS runs in post-order: children are fully resolved before the parent looks at them. Each child hands up two numbers, and the parent combines them with two simple recurrences. The root's answer is the max of its two values. Total work: O(n). Beyond the dp array itself, no memo-lookup machinery is needed: each node is computed exactly once as the recursion unwinds, so there are no cache hits or misses to reason about.
import java.util.*;

public class MaxWeightIndependentSet {
    // dp[node][0] = max weight in subtree rooted at 'node' when node is EXCLUDED
    // dp[node][1] = max weight in subtree rooted at 'node' when node is INCLUDED
    private static long[][] dp;
    private static List<List<Integer>> adjacencyList;
    private static int[] nodeWeight;

    public static void main(String[] args) {
        int totalNodes = 7;
        nodeWeight = new int[]{0, 10, 5, 15, 7, 6, 1, 20}; // 1-indexed; index 0 unused
        adjacencyList = new ArrayList<>();
        for (int i = 0; i <= totalNodes; i++) {
            adjacencyList.add(new ArrayList<>());
        }
        // Build undirected tree edges (parent tracking in the DFS avoids going back up)
        addEdge(1, 2); addEdge(1, 3);
        addEdge(2, 4); addEdge(2, 5);
        addEdge(3, 6); addEdge(3, 7);
        //            1 (10)
        //           /      \
        //        2 (5)    3 (15)
        //        /   \     /   \
        //     4 (7) 5 (6) 6 (1) 7 (20)

        dp = new long[totalNodes + 1][2];
        // Start DFS from node 1, with no parent (-1)
        dfs(1, -1);

        long answer = Math.max(dp[1][0], dp[1][1]);
        System.out.println("Max Independent Set Weight: " + answer);

        // Show per-node DP values for clarity
        System.out.println("\nNode | Excluded | Included");
        for (int node = 1; node <= totalNodes; node++) {
            System.out.printf(" %-4d %-9d %d%n", node, dp[node][0], dp[node][1]);
        }
    }

    // Post-order DFS: children are fully computed before the parent uses their values
    private static void dfs(int currentNode, int parentNode) {
        // Base case: start with just this node's own contribution
        dp[currentNode][0] = 0;                       // excluded: contributes 0
        dp[currentNode][1] = nodeWeight[currentNode]; // included: contributes own weight
        for (int neighbor : adjacencyList.get(currentNode)) {
            if (neighbor == parentNode) continue; // don't recurse back to parent
            dfs(neighbor, currentNode);           // resolve child FIRST (post-order)
            // If currentNode is EXCLUDED, each child freely picks the best of its two states
            dp[currentNode][0] += Math.max(dp[neighbor][0], dp[neighbor][1]);
            // If currentNode is INCLUDED, no child can also be included
            dp[currentNode][1] += dp[neighbor][0];
        }
    }

    private static void addEdge(int u, int v) {
        adjacencyList.get(u).add(v);
        adjacencyList.get(v).add(u);
    }
}
Max Independent Set Weight: 44

Node | Excluded | Included
 1    34        44
 2    13        5
 3    21        15
 4    0         7
 5    0         6
 6    0         1
 7    0         20
Tree Diameter and Longest Path: When Two Children Talk to Each Other
Most tree DP combines each child with the current node. The diameter problem is sneaky because the longest path might not pass through the root at all — it could be a path entirely inside one subtree, or it could use the root as a bend point connecting two different children's longest arms. That 'bend point' idea is the key insight.
For every node, define depth(node) as the longest simple path from that node down into its subtree. Computing depth is a standard post-order DP: depth(leaf) = 0, depth(node) = 1 + max(depth(child) for all children).
The diameter update is where it gets interesting. As you process a node's children one by one, maintain longestArmSoFar: the longest arm, measured in edges and including the edge from the current node down into the child, among children processed so far. When you compute a new child's depth, the new arm has length depth(newChild) + 1, and the candidate diameter for a path bending at the current node is longestArmSoFar + depth(newChild) + 1. Update the global diameter with that candidate, then extend longestArmSoFar if the new arm is longer. For the first child, longestArmSoFar is 0 and the candidate is just the single arm, which is itself a valid path, so no special case is needed.
This runs in O(n) and handles all the edge cases: path entirely inside a subtree (handled by the recursive calls updating globalDiameter), path through the root (handled at the root level), star graphs, paths, single nodes. No special cases needed.
import java.util.*;

public class TreeDiameter {
    private static List<List<Integer>> adjacencyList;
    private static int globalDiameter = 0;

    public static void main(String[] args) {
        int totalNodes = 9;
        adjacencyList = new ArrayList<>();
        for (int i = 0; i <= totalNodes; i++) {
            adjacencyList.add(new ArrayList<>());
        }
        // Tree shape:
        //         1
        //        / \
        //       2   3
        //      /|    \
        //     4 5     6
        //    /         \
        //   7           8
        //  /
        // 9
        addEdge(1, 2); addEdge(1, 3); addEdge(2, 4); addEdge(2, 5);
        addEdge(3, 6); addEdge(4, 7); addEdge(6, 8); addEdge(7, 9);

        globalDiameter = 0;
        computeDepthAndDiameter(1, -1);
        System.out.println("Tree Diameter (longest path in edges): " + globalDiameter);
        // Expected: path 9 -> 7 -> 4 -> 2 -> 1 -> 3 -> 6 -> 8 = 7 edges
    }

    /**
     * Returns the length in edges of the longest path going DOWN from currentNode.
     * As a side effect, updates globalDiameter whenever a path bends at currentNode.
     */
    private static int computeDepthAndDiameter(int currentNode, int parentNode) {
        int longestArmDownward = 0; // longest arm (in edges) among children processed so far
        for (int neighbor : adjacencyList.get(currentNode)) {
            if (neighbor == parentNode) continue;
            // Post-order: resolve the child first
            int childDepth = computeDepthAndDiameter(neighbor, currentNode);
            int newArm = childDepth + 1; // the edge to this child plus its downward depth
            // A path can BEND at currentNode: the longest previously seen arm joined
            // with the new arm. For the first child longestArmDownward is 0, so the
            // candidate is just the single arm, which is itself a valid path.
            globalDiameter = Math.max(globalDiameter, longestArmDownward + newArm);
            // Extend the longest downward arm for future siblings
            longestArmDownward = Math.max(longestArmDownward, newArm);
        }
        // Return how deep we can go from currentNode downward (used by the parent).
        // Leaves return 0, so single-node trees correctly leave the diameter at 0.
        return longestArmDownward;
    }

    private static void addEdge(int u, int v) {
        adjacencyList.get(u).add(v);
        adjacencyList.get(v).add(u);
    }
}
Rerooting Technique: Answering 'What If Every Node Were Root?' in O(n)
Here's a problem that breaks naive tree DP: 'For every node, find the sum of distances to all other nodes.' Naively you'd re-root the tree at each node and rerun the DFS — O(n²). With rerooting, you do two DFS passes and get O(n).
Pass 1 (downward, standard post-order): root the tree arbitrarily at node 1. Compute subtreeSize[v] (how many nodes in v's subtree) and distanceSum[v] (sum of distances from v to all nodes in its subtree). distanceSum[v] = sum over all children c of (distanceSum[c] + subtreeSize[c]).
Pass 2 (rerooting, pre-order): propagate answers from parent to child. When you're at node v with parent p, the full distance sum for v is: distanceSum[v] (distances to nodes below v) plus (fullAnswer[p] - distanceSum[v] - subtreeSize[v]) (distances from p to everything NOT in v's subtree) plus (n - subtreeSize[v]) (one extra edge for each non-subtree node reaching v through p). This gives fullAnswer[v] from fullAnswer[p] in O(1).
The rerooting trick generalizes: any problem where you can express 'child's answer using parent's answer' by reversing a transition is a rerooting candidate. Think: k-th ancestor problems, centroid decomposition setup, competitive programming problems asking for 'answer at each node'.
import java.util.*;

public class SumOfDistancesRerooting {
    private static List<List<Integer>> adjacencyList;
    private static long[] subtreeSize; // number of nodes in subtree rooted at v
    private static long[] downwardSum; // sum of distances from v to all nodes IN its subtree
    private static long[] fullAnswer;  // final answer: sum of distances from v to ALL nodes
    private static int totalNodes;

    public static void main(String[] args) {
        totalNodes = 6;
        adjacencyList = new ArrayList<>();
        for (int i = 0; i < totalNodes; i++) adjacencyList.add(new ArrayList<>());
        // Tree (0-indexed):
        //       0
        //      /|\
        //     1 2 3
        //      / \
        //     4   5
        addEdge(0, 1); addEdge(0, 2); addEdge(0, 3);
        addEdge(2, 4); addEdge(2, 5);

        subtreeSize = new long[totalNodes];
        downwardSum = new long[totalNodes];
        fullAnswer = new long[totalNodes];

        // PASS 1: root at 0, compute subtreeSize and downwardSum (post-order)
        dfsDownward(0, -1);

        // PASS 2: propagate full answers from root down (pre-order)
        fullAnswer[0] = downwardSum[0]; // root's answer IS downwardSum (no nodes above it)
        dfsRerooting(0, -1);

        System.out.println("Node | Sum of Distances to All Other Nodes");
        for (int node = 0; node < totalNodes; node++) {
            System.out.printf(" %-4d %d%n", node, fullAnswer[node]);
        }
        // Verify node 0 manually:
        // dist(0,1)=1, dist(0,2)=1, dist(0,3)=1, dist(0,4)=2, dist(0,5)=2 → sum = 7
        System.out.println("\nManual check for node 0: 1+1+1+2+2 = 7");
    }

    // Post-order: fill subtreeSize and downwardSum
    private static void dfsDownward(int currentNode, int parentNode) {
        subtreeSize[currentNode] = 1; // count this node itself
        downwardSum[currentNode] = 0;
        for (int neighbor : adjacencyList.get(currentNode)) {
            if (neighbor == parentNode) continue;
            dfsDownward(neighbor, currentNode); // child computed first
            subtreeSize[currentNode] += subtreeSize[neighbor];
            // Every node in neighbor's subtree is one edge farther from currentNode
            // than it is from neighbor, hence the extra +subtreeSize[neighbor]
            downwardSum[currentNode] += downwardSum[neighbor] + subtreeSize[neighbor];
        }
    }

    // Pre-order: propagate fullAnswer from parent to each child
    private static void dfsRerooting(int currentNode, int parentNode) {
        for (int neighbor : adjacencyList.get(currentNode)) {
            if (neighbor == parentNode) continue;
            // When we 'reroot' at neighbor:
            // - Nodes inside neighbor's subtree keep their downwardSum[neighbor] distances.
            // - Nodes OUTSIDE neighbor's subtree (count = totalNodes - subtreeSize[neighbor])
            //   are accounted for in fullAnswer[currentNode]. Strip out the subtree's own
            //   contribution, then add one extra edge for each outside node:
            long outsideContribution = fullAnswer[currentNode]
                    - downwardSum[neighbor] // remove neighbor's subtree's downward distances
                    - subtreeSize[neighbor] // remove the 1-edge cost for each subtree node
                    + (totalNodes - subtreeSize[neighbor]); // each outside node is now +1 edge
            fullAnswer[neighbor] = downwardSum[neighbor] + outsideContribution;
            dfsRerooting(neighbor, currentNode); // propagate further down
        }
    }

    private static void addEdge(int u, int v) {
        adjacencyList.get(u).add(v);
        adjacencyList.get(v).add(u);
    }
}
Node | Sum of Distances to All Other Nodes
 0    7
 1    11
 2    7
 3    11
 4    11
 5    11

Manual check for node 0: 1+1+1+2+2 = 7
Stack Overflow, Stack Conversion, and Production-Level Tree DP
Recursive tree DP looks clean but has a serious production problem: stack overflow on deep trees. A balanced binary tree with one million nodes has depth ~20 — fine. A path graph with one million nodes has depth one million — your JVM default stack (512KB–1MB) will throw StackOverflowError around depth 5,000–10,000.
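If rewriting the recursion isn't feasible, one pragmatic stopgap is to run the DFS on a worker thread created with an explicit stack size, via Java's Thread(ThreadGroup, Runnable, String, long stackSize) constructor (the size is a hint the JVM may round or ignore, though HotSpot honors it). A minimal sketch, with the class and method names (BigStackDemo, countDown) invented for illustration:

```java
public class BigStackDemo {
    // Stand-in for a deep recursive DFS: recursion depth equals 'depth',
    // mimicking tree DP on a path graph with that many nodes.
    static int countDown(int depth) {
        if (depth == 0) return 0;
        return 1 + countDown(depth - 1);
    }

    public static void main(String[] args) throws InterruptedException {
        final int depth = 1_000_000;
        final int[] result = new int[1]; // one-element array so the lambda can write to it
        // Run the recursion on a thread with a 256 MB stack instead of the JVM default.
        Thread worker = new Thread(null, () -> result[0] = countDown(depth), "deep-dfs", 1 << 28);
        worker.start();
        worker.join();
        System.out.println("Recursed to depth: " + result[0]); // prints 1000000
    }
}
```

This keeps the readable recursive code, at the cost of a platform-dependent knob; the iterative conversion described next remains the more portable fix.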
The fix is iterative post-order DFS using an explicit stack. The pattern: push the root, pop nodes and record them as you go (this yields a pre-order), then walk the recorded list in reverse. The reversal guarantees every child is processed before its parent, which is exactly the ordering post-order DP needs. This converts the recursive DFS into a two-phase loop with O(n) space on the heap instead of the call stack.
A second production concern is parent tracking. A visited boolean array also works (and the iterative version below needs one anyway), but for recursive tree DFS, passing parentNode explicitly is the leaner choice: no extra O(n) allocation, and the tree assumption is stated right in the code. Note that neither approach rescues you from malformed input; if a supposed tree might contain duplicate or extra edges, validate it first (exactly n - 1 edges, connected) rather than hoping the traversal copes.
The third concern: for very large trees (10⁶+ nodes), your adjacency list representation matters. ArrayList<List<Integer>> stores every neighbor as a boxed Integer inside a separate object, wasting memory and scattering it across the heap; flat int-array representations such as CSR (compressed sparse row) pack the same adjacency into a few contiguous arrays, cutting memory use and improving cache locality. The iterative example below uses that style.
import java.util.*;

/**
 * Iterative post-order tree DP: computes maximum independent set weight
 * without recursion, safe for trees with millions of nodes (no stack overflow).
 * Adjacency is stored in flat int arrays (a linked "forward star" layout,
 * closely related to CSR) for memory-efficient storage.
 */
public class IterativeTreeDP {
    public static void main(String[] args) {
        int totalNodes = 7;
        int[] nodeWeight = {0, 10, 5, 15, 7, 6, 1, 20}; // 1-indexed

        // Flat-array adjacency: each undirected edge becomes two directed edges.
        // head[v] = index of the first edge out of v; nextEdge chains the rest.
        int[][] edges = {{1,2},{1,3},{2,4},{2,5},{3,6},{3,7}}; // undirected
        int[] head = new int[totalNodes + 1];
        Arrays.fill(head, -1);
        int[] edgeTo = new int[edges.length * 2];   // destination of each directed edge
        int[] nextEdge = new int[edges.length * 2]; // linked-list chain through edges
        int edgeCount = 0;
        for (int[] edge : edges) {
            // Add edge u->v
            edgeTo[edgeCount] = edge[1];
            nextEdge[edgeCount] = head[edge[0]];
            head[edge[0]] = edgeCount++;
            // Add edge v->u
            edgeTo[edgeCount] = edge[0];
            nextEdge[edgeCount] = head[edge[1]];
            head[edge[1]] = edgeCount++;
        }

        long[] dpExcluded = new long[totalNodes + 1]; // dp[v][0]
        long[] dpIncluded = new long[totalNodes + 1]; // dp[v][1]
        int[] parentOf = new int[totalNodes + 1];
        Arrays.fill(parentOf, -1);

        // Phase 1: collect nodes with an explicit stack (pre-order visit order)
        List<Integer> processingOrder = new ArrayList<>();
        Deque<Integer> dfsStack = new ArrayDeque<>();
        boolean[] visited = new boolean[totalNodes + 1];
        dfsStack.push(1); // root = node 1
        visited[1] = true;
        while (!dfsStack.isEmpty()) {
            int currentNode = dfsStack.pop();
            processingOrder.add(currentNode); // will be reversed for post-order processing
            for (int eid = head[currentNode]; eid != -1; eid = nextEdge[eid]) {
                int neighbor = edgeTo[eid];
                if (!visited[neighbor]) {
                    visited[neighbor] = true;
                    parentOf[neighbor] = currentNode;
                    dfsStack.push(neighbor);
                }
            }
        }
        // Reverse the pre-order so every child appears before its parent
        Collections.reverse(processingOrder);

        // Phase 2: process children-before-parents (guaranteed by the reversal)
        for (int currentNode : processingOrder) {
            dpExcluded[currentNode] = 0;
            dpIncluded[currentNode] = nodeWeight[currentNode];
            for (int eid = head[currentNode]; eid != -1; eid = nextEdge[eid]) {
                int neighbor = edgeTo[eid];
                if (neighbor == parentOf[currentNode]) continue; // skip parent edge
                // neighbor is a child; it was processed before currentNode
                dpExcluded[currentNode] += Math.max(dpExcluded[neighbor], dpIncluded[neighbor]);
                dpIncluded[currentNode] += dpExcluded[neighbor]; // included node: children excluded
            }
        }

        long answer = Math.max(dpExcluded[1], dpIncluded[1]);
        System.out.println("Max Independent Set Weight (iterative, CSR): " + answer);
        System.out.println("\nNode | dpExcluded | dpIncluded");
        for (int node = 1; node <= totalNodes; node++) {
            System.out.printf(" %-4d %-11d %d%n", node, dpExcluded[node], dpIncluded[node]);
        }
    }
}
Max Independent Set Weight (iterative, CSR): 44

Node | dpExcluded | dpIncluded
 1    34          44
 2    13          5
 3    21          15
 4    0           7
 5    0           6
 6    0           1
 7    0           20
| Aspect | Recursive Tree DP | Iterative Tree DP (Explicit Stack) |
|---|---|---|
| Code Readability | High — mirrors the recurrence directly | Medium — requires explicit order collection |
| Stack Overflow Risk | YES — O(depth) call frames, fails on path graphs with n > ~8000 | NO — uses heap-allocated Deque, safe for any n |
| Memory Usage | O(depth) call stack + O(n) heap | O(n) heap only — predictable |
| Cache Performance | Poor on deep paths — many function call frames | Better — flat array iteration in Phase 2 |
| Debugging Ease | Easier — stack traces are meaningful | Harder — processing order is implicit |
| When to Use | Balanced trees, competitive programming, n < 50,000 | Production systems, path-like trees, n > 100,000 |
| Rerooting Support | Straightforward second DFS pass | Requires two separate order-collection phases |
🎯 Key Takeaways
- Your DP state definition determines everything — it must capture exactly what a parent needs from a subtree. For independent set: two values (included/excluded). For diameter: the longest downward arm. Get this wrong and no amount of clever coding fixes it.
- The rerooting technique converts O(n²) 're-root and rerun' into O(n) by expressing each child's full answer as an O(1) transformation of its parent's full answer — learned in Pass 1, propagated in Pass 2.
- Recursive tree DP fails with StackOverflowError on path-like trees with depth > ~8,000 in Java. The iterative fix — collect nodes via explicit stack into a list, reverse for post-order, process — is a one-time pattern worth memorizing for production use.
- The diameter 'bending at a node' pattern (join the two longest downward arms at one node: depth1 + depth2 + 2 in edges, or equivalently arm1 + arm2 when each arm already counts its connecting edge) recurs in dozens of problems: longest path with constraints, max weight path, tree center calculation. Recognizing it saves you from re-deriving the insight each time.
⚠ Common Mistakes to Avoid
- ✕ Mistake 1: Defining the DP state too coarsely, e.g. storing only a single 'best value' per node when the parent needs to know TWO things (best with node included vs. excluded). Symptom: wrong answers on trees where including the parent changes the child's contribution. Fix: always ask 'what does my parent need to know about me?'; if the answer has cases, your state needs multiple values.
- ✕ Mistake 2: Forgetting to skip the parent edge during DFS. Symptom: on undirected trees, the DFS immediately recurses back into the parent, causing infinite recursion (StackOverflowError) or double-counting. Fix: pass 'int parentNode' into your DFS function and skip any neighbor that equals parentNode; don't rely on a visited[] array alone for recursive tree DFS.
- ✕ Mistake 3: Applying the rerooting formula without accounting for how the 'outside world' changes at the new root. Symptom: wrong fullAnswer values for non-root nodes, typically off by (n - 2*subtreeSize[child]) when the adjustment is dropped entirely. Fix: when propagating from parent p to child c, remember that from c's perspective the outside world has (n - subtreeSize[c]) nodes, each one edge farther than it appeared from p. The formula is fullAnswer[c] = downwardSum[c] + (fullAnswer[p] - downwardSum[c] - subtreeSize[c]) + (n - subtreeSize[c]).
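A quick sanity check on Mistake 3: the transition collapses algebraically, since downwardSum[c] cancels, to fullAnswer[c] = fullAnswer[p] + n - 2*subtreeSize[c]. The sketch below (class name RerootingFormulaCheck is made up for illustration) evaluates that simplified form on the article's six-node example, where fullAnswer[0] = 7 and node 2's subtree {2, 4, 5} has size 3:

```java
public class RerootingFormulaCheck {
    // Simplified rerooting transition: moving the root from parent p to child c
    // makes the (n - size[c]) outside nodes one edge farther and the size[c]
    // subtree nodes one edge closer, hence the net n - 2*size[c] change.
    static long rerootChild(long fullAnswerParent, long subtreeSizeChild, long n) {
        return fullAnswerParent + n - 2 * subtreeSizeChild;
    }

    public static void main(String[] args) {
        long n = 6, fullAnswerRoot = 7; // root 0: distances 1+1+1+2+2 = 7
        // Node 2 (subtree size 3): 7 + 6 - 6 = 7
        System.out.println("fullAnswer[2] = " + rerootChild(fullAnswerRoot, 3, n));
        // A leaf child of the root (subtree size 1): 7 + 6 - 2 = 11
        System.out.println("fullAnswer[1] = " + rerootChild(fullAnswerRoot, 1, n));
    }
}
```

If your Pass 2 produces numbers that disagree with this one-line formula, the bug is almost always in the subtree/outside bookkeeping.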
Interview Questions on This Topic
- Q: Given a binary tree where each node has a value, find the maximum sum path between any two nodes (the path doesn't have to pass through the root). Walk me through your DP state and transition. How does your approach differ from the tree diameter problem?
- Q: How would you compute the sum of distances from every node to all other nodes in an unweighted tree in O(n) time? What are the two DFS passes doing, and why is a single pass insufficient?
- Q: You have a tree with n = 10^6 nodes given as a random permutation of edges; it could be a path graph. Your recursive tree DP solution works correctly but throws StackOverflowError. How do you fix it without increasing the JVM stack size, and what's the time/space complexity of your fix?
Frequently Asked Questions
What is the difference between DP on trees and DP on graphs?
Trees have no cycles, which means there's a clear parent-child direction you can exploit — process children before parents (post-order) and each node is visited exactly once. General graphs have cycles, requiring memoization with visited states to avoid infinite loops. Tree DP is almost always O(n); graph DP on cyclic graphs often requires Bellman-Ford or other techniques and is more expensive.
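To make the contrast concrete, here is a minimal sketch of DP on a DAG (the class DagLongestPath and its edge set are invented for illustration). Node 3 is reachable by two routes, exactly the kind of sharing a tree never has, and the memo array is what keeps each node's value computed once, for O(V + E) total:

```java
import java.util.*;

public class DagLongestPath {
    /** Longest path (in edges) starting at v; memo[v] == -1 means "not computed yet". */
    static int longestFrom(int v, List<List<Integer>> out, int[] memo) {
        if (memo[v] != -1) return memo[v]; // already solved: reuse, don't recompute
        int best = 0;
        for (int next : out.get(v)) best = Math.max(best, 1 + longestFrom(next, out, memo));
        return memo[v] = best;
    }

    public static void main(String[] args) {
        int n = 5;
        List<List<Integer>> out = new ArrayList<>();
        for (int i = 0; i < n; i++) out.add(new ArrayList<>());
        // A small DAG: 0->1, 0->2, 1->3, 2->3, 3->4
        // (two routes into node 3 - the sharing that memoization handles)
        out.get(0).add(1); out.get(0).add(2);
        out.get(1).add(3); out.get(2).add(3);
        out.get(3).add(4);
        int[] memo = new int[n];
        Arrays.fill(memo, -1);
        System.out.println("Longest path from node 0: " + longestFrom(0, out, memo)); // 3 (0->1->3->4)
    }
}
```

On a tree the memo check would never fire, which is why tree DP can skip it and rely purely on post-order.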
How do I choose the root for tree DP when the problem doesn't specify one?
Pick any node — conventionally node 1 or node 0. The final answer at the root will be globally correct regardless of which root you choose, as long as your DP state and transitions are correctly defined. For rerooting problems, you pick an arbitrary root for Pass 1 and then correct all answers in Pass 2, so the choice truly doesn't matter.
Can tree DP handle weighted edges, not just weighted nodes?
Yes — you just incorporate the edge weight into the transition instead of (or in addition to) the node weight. For example, in the tree diameter problem with weighted edges, replace the '+1' in the arm length formula with '+edgeWeight(currentNode, child)'. Everything else stays identical. Store edge weights in your adjacency list as pairs (neighbor, weight) instead of just neighbor.
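As a concrete sketch of that substitution (the class WeightedTreeDiameter and its sample tree are made up for illustration), here is the diameter DFS with the '+1' replaced by '+weight', using int[2] pairs in the adjacency list:

```java
import java.util.*;

public class WeightedTreeDiameter {
    // Adjacency list of {neighbor, edgeWeight} pairs (int[2] as a lightweight pair)
    static List<List<int[]>> adj;
    static long diameter = 0;

    /** Longest weighted downward path from currentNode; updates 'diameter' at each bend. */
    static long dfs(int currentNode, int parentNode) {
        long longestArm = 0;
        for (int[] edge : adj.get(currentNode)) {
            int neighbor = edge[0], weight = edge[1];
            if (neighbor == parentNode) continue;
            // The only change from the unweighted version: '+ weight' instead of '+ 1'
            long newArm = dfs(neighbor, currentNode) + weight;
            diameter = Math.max(diameter, longestArm + newArm); // path bending here
            longestArm = Math.max(longestArm, newArm);
        }
        return longestArm;
    }

    public static void main(String[] args) {
        int n = 5;
        adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        // Weighted tree: 0-1 (w=4), 0-2 (w=3), 2-3 (w=2), 2-4 (w=6)
        addEdge(0, 1, 4); addEdge(0, 2, 3); addEdge(2, 3, 2); addEdge(2, 4, 6);
        dfs(0, -1);
        System.out.println("Weighted diameter: " + diameter); // path 1-0-2-4 = 4+3+6 = 13
    }

    static void addEdge(int u, int v, int w) {
        adj.get(u).add(new int[]{v, w});
        adj.get(v).add(new int[]{u, w});
    }
}
```

Using long for arm lengths matters here: with up to 10⁶ edges of weight up to 10⁹, the diameter overflows int.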
Written and reviewed by senior developers with real-world experience across enterprise, startup and open-source projects. Every article on TheCodeForge is written to be clear, accurate and genuinely useful — not just SEO filler.