
Matrix Chain Multiplication with Dynamic Programming — Deep Dive

In Plain English 🔥
Imagine you're wrapping five birthday presents and you can combine them in any order using tape. Some orders waste almost no tape; others waste rolls of it — same presents, wildly different effort. Matrix Chain Multiplication is exactly that: you have a chain of matrices to multiply together, and the ORDER in which you pair them up changes the total number of arithmetic operations dramatically. DP finds the pairing order that does the least work.
⚡ Quick Answer
Matrix Chain Multiplication finds the cheapest order in which to parenthesize a product of n matrices. Every order gives the same result (matrix multiplication is associative), but the operation counts differ enormously. Dynamic programming solves it in O(n³) time and O(n²) space using the recurrence cost[i][j] = min over k of cost[i][k] + cost[k+1][j] + dims[i] × dims[k+1] × dims[j+1].

Every time a graphics engine renders a scene, a robotics system computes joint transforms, or a machine learning framework chains layer activations, it multiplies a sequence of matrices. The math gives you the same answer no matter what order you parenthesize the chain — but the COST in floating-point operations can differ by orders of magnitude. A chain of just six matrices can have 42 different parenthesizations, and the worst choice can be thousands of times more expensive than the best. At scale, this is the difference between a real-time application and one that lags.

The brute-force approach — try every parenthesization and pick the cheapest — has a runtime that grows exponentially: the number of parenthesizations follows the Catalan number sequence, roughly 4ⁿ/n^(3/2). Matrix Chain Multiplication solved with Dynamic Programming collapses that to O(n³) time and O(n²) space by recognizing a crucial truth: the optimal solution to the full chain is built from optimal solutions to sub-chains. That overlapping-subproblem structure is the heartbeat of DP.
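To see the blow-up concretely, here is a tiny standalone helper (an illustrative sketch, not part of the article's listings) that counts the number of parenthesizations of an n-matrix chain directly from the split-point definition:

```java
public class ParenthesizationCount {

    /**
     * Number of ways to fully parenthesize a chain of n matrices.
     * This is the Catalan number C(n-1): the top-level split puts
     * k matrices on the left and n-k on the right.
     */
    public static long count(int n) {
        if (n == 1) return 1; // a single matrix needs no parentheses
        long total = 0;
        for (int k = 1; k < n; k++) {
            total += count(k) * count(n - k);
        }
        return total;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 10; n++) {
            System.out.println(n + " matrices: " + count(n) + " parenthesizations");
        }
    }
}
```

For n = 6 this prints 42, matching the figure quoted above; by n = 10 it is already 4,862, and the growth only accelerates from there.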

By the end of this article you'll be able to derive the recurrence relation from scratch, implement both a top-down memoized and a bottom-up tabulated solution in Java, read the parenthesization table to reconstruct the actual split points, reason about the space–time tradeoffs, and handle the edge cases that trip people up in interviews and production code.

Why the Multiplication Order Changes Everything — The Cost Model

Before writing a single line of DP, you need to own the cost model. Multiplying an (a × b) matrix by a (b × c) matrix yields an (a × c) matrix and costs exactly a × b × c scalar multiplications. That's it — that's the only formula you need.

Now consider three matrices: A (10×30), B (30×5), C (5×60). Two parenthesizations exist:

  • (AB)C: AB costs 10×30×5 = 1,500 ops → result is 10×5. Then (10×5)×C costs 10×5×60 = 3,000 ops. Total: 4,500.
  • A(BC): BC costs 30×5×60 = 9,000 ops → result is 30×60. Then A×(BC), a (10×30)×(30×60) product, costs 10×30×60 = 18,000 ops. Total: 27,000.

Same three matrices, same final result, but option one is SIX TIMES cheaper. Now scale that intuition to 20 matrices in a deep learning pipeline and you see why this matters.

The dimension array is the key input. For n matrices, you encode their dimensions as an array of n+1 integers: dims[i] and dims[i+1] are the row and column counts of matrix i. This compact representation is not an accident — it's what lets the recurrence work cleanly.

CostModelDemo.java · JAVA
public class CostModelDemo {

    /**
     * Computes the scalar multiplication cost of multiplying
     * matrix A (rows x shared) by matrix B (shared x cols).
     * This is the atomic cost function the DP recurrence is built on.
     */
    public static long multiplicationCost(int rows, int shared, int cols) {
        return (long) rows * shared * cols; // cast to long — large dims overflow int
    }

    public static void main(String[] args) {
        // Three matrices: A(10x30), B(30x5), C(5x60)
        // Encoded as a dimension array: dims[i] x dims[i+1] = matrix i
        int[] dims = {10, 30, 5, 60};

        // Parenthesization 1: (A * B) * C
        long costAB   = multiplicationCost(dims[0], dims[1], dims[2]); // 10*30*5
        long costAB_C = multiplicationCost(dims[0], dims[2], dims[3]); // 10*5*60
        long total1   = costAB + costAB_C;

        // Parenthesization 2: A * (B * C)
        long costBC   = multiplicationCost(dims[1], dims[2], dims[3]); // 30*5*60
        long costA_BC = multiplicationCost(dims[0], dims[1], dims[3]); // 10*30*60
        long total2   = costBC + costA_BC;

        System.out.println("=== Cost Model Demo ===");
        System.out.printf("(A*B)*C  => AB=%,d  then (AB)*C=%,d  => Total: %,d%n",
                          costAB, costAB_C, total1);
        System.out.printf("A*(B*C)  => BC=%,d  then A*(BC)=%,d => Total: %,d%n",
                          costBC, costA_BC, total2);
        System.out.printf("Best parenthesization is %dx cheaper.%n",
                          total2 / total1);
    }
}
▶ Output
=== Cost Model Demo ===
(A*B)*C => AB=1,500 then (AB)*C=3,000 => Total: 4,500
A*(B*C) => BC=9,000 then A*(BC)=18,000 => Total: 27,000
Best parenthesization is 6x cheaper.
⚠️
Watch Out: Integer Overflow on Large Dimensions
If matrix dimensions reach the thousands (common in ML weight matrices), the product rows × shared × cols easily exceeds Integer.MAX_VALUE (~2.1B). Always accumulate costs into a long, not an int. The cast `(long) rows * shared * cols` must come first, not last: Java evaluates left to right, so making the first operand a long promotes the whole chain to long arithmetic before any overflow can occur.

Deriving the DP Recurrence From Scratch — No Memorization Required

Here's how to derive the recurrence yourself instead of memorizing it. Ask: 'If I HAD to make exactly one top-level split in the chain from matrix i to matrix j, where would I put it?' The answer is some index k where i ≤ k < j. At that split, you'd multiply the left sub-chain (i to k) by the right sub-chain (k+1 to j). The cost of that top-level multiplication is dims[i] × dims[k+1] × dims[j+1]. The total cost is that plus the optimal costs of each sub-chain.

So: cost[i][j] = min over all k in [i, j-1] of (cost[i][k] + cost[k+1][j] + dims[i] × dims[k+1] × dims[j+1]).

Base case: cost[i][i] = 0 because a single matrix needs no multiplication.

This recurrence has OPTIMAL SUBSTRUCTURE — the optimal parenthesization of a chain contains optimal parenthesizations of its sub-chains. And it has OVERLAPPING SUBPROBLEMS — computing cost[0][5] will repeatedly ask for cost[0][2] and cost[3][5] across different values of k. Those two properties are exactly why DP applies here.
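Before adding any tables, the recurrence can be transcribed directly into a naive recursive method. This sketch (illustrative, separate from the article's main listings) is exponential precisely because of those overlapping subproblems, which is what the DP versions below fix:

```java
public class NaiveRecursion {

    /**
     * Minimum scalar multiplications for matrices i..j, straight from the recurrence.
     * No caching: the same (i, j) pairs are recomputed exponentially many times.
     */
    public static long minCost(int[] dims, int i, int j) {
        if (i == j) return 0L; // base case: a single matrix needs no multiplication
        long best = Long.MAX_VALUE;
        for (int k = i; k < j; k++) {
            long candidate = minCost(dims, i, k)
                           + minCost(dims, k + 1, j)
                           + (long) dims[i] * dims[k + 1] * dims[j + 1];
            best = Math.min(best, candidate);
        }
        return best;
    }

    public static void main(String[] args) {
        int[] dims = {10, 30, 5, 60}; // A(10x30), B(30x5), C(5x60)
        System.out.println(minCost(dims, 0, dims.length - 2)); // prints 4500
    }
}
```

On the three-matrix example this correctly picks the 4,500-op ordering. Memoization and tabulation change nothing about this recurrence; they only stop the recomputation.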

For bottom-up tabulation, you fill the table by increasing chain length (called 'span' or 'chain length l'), starting at l=2 (pairs of adjacent matrices) and growing to l=n (the full chain). The order matters — you need smaller sub-problems solved before larger ones.

MatrixChainDP.java · JAVA
import java.util.Arrays;

public class MatrixChainDP {

    /**
     * Solves Matrix Chain Multiplication using bottom-up DP (tabulation).
     *
     * @param dims  Array of n+1 dimensions where matrix i has size dims[i] x dims[i+1].
     * @return      A Result object containing the minimum cost and the split table.
     */
    public static Result minMultiplicationCost(int[] dims) {
        int n = dims.length - 1; // number of matrices

        // cost[i][j] = minimum scalar multiplications to compute matrices i..j
        long[][] cost = new long[n][n];

        // split[i][j] = the index k where the optimal top-level split occurs for chain i..j
        int[][] split = new int[n][n];

        // Base case: a single matrix costs nothing to "multiply"
        // Java initializes long[][] to 0, so base cases are already set.

        // Fill by increasing chain length (span) from 2 to n
        for (int span = 2; span <= n; span++) {

            // i is the start index of the sub-chain
            for (int i = 0; i <= n - span; i++) {
                int j = i + span - 1; // j is the end index

                cost[i][j] = Long.MAX_VALUE; // start with worst possible cost

                // Try every possible top-level split point k
                for (int k = i; k < j; k++) {
                    // Cost of multiplying left sub-chain (i..k) by right sub-chain (k+1..j)
                    // The resulting matrices have dims: dims[i] x dims[k+1] x dims[j+1]
                    long splitCost = cost[i][k]
                                   + cost[k + 1][j]
                                   + (long) dims[i] * dims[k + 1] * dims[j + 1];

                    if (splitCost < cost[i][j]) {
                        cost[i][j]  = splitCost; // record the better cost
                        split[i][j] = k;          // remember WHERE we split
                    }
                }
            }
        }

        return new Result(cost[0][n - 1], split, n);
    }

    /** Recursively builds the optimal parenthesization string from the split table. */
    static String buildParenthesization(int[][] split, int i, int j) {
        if (i == j) {
            return "M" + (i + 1); // matrices are 1-indexed for readability
        }
        int k = split[i][j];
        String left  = buildParenthesization(split, i, k);
        String right = buildParenthesization(split, k + 1, j);
        return "(" + left + " x " + right + ")";
    }

    public static class Result {
        public final long minCost;
        public final String parenthesization;

        Result(long minCost, int[][] split, int n) {
            this.minCost          = minCost;
            this.parenthesization = buildParenthesization(split, 0, n - 1);
        }
    }

    public static void main(String[] args) {
        // 6 matrices with these dimensions:
        // M1: 30x35, M2: 35x15, M3: 15x5, M4: 5x10, M5: 10x20, M6: 20x25
        int[] dims = {30, 35, 15, 5, 10, 20, 25};

        Result result = minMultiplicationCost(dims);

        System.out.println("=== Matrix Chain Multiplication (Bottom-Up DP) ===");
        System.out.printf("Minimum scalar multiplications: %,d%n", result.minCost);
        System.out.println("Optimal parenthesization: " + result.parenthesization);
    }
}
▶ Output
=== Matrix Chain Multiplication (Bottom-Up DP) ===
Minimum scalar multiplications: 15,125
Optimal parenthesization: ((M1 x (M2 x M3)) x ((M4 x M5) x M6))
🔥
Interview Gold: Why Fill by Span, Not by Row?
Interviewers love asking why the outer loop iterates over chain length (span) rather than over row i. The answer is dependency order: cost[i][j] depends on cost[i][k] and cost[k+1][j], which are both shorter sub-chains. Filling by increasing span guarantees every dependency is resolved before it's needed. Filling rows top to bottom would try to read cost[k+1][j] before it's computed. (Iterating i downward from n-1 with j increasing also respects the dependencies, but span order is the easiest version to get right.)

Top-Down Memoization vs. Bottom-Up Tabulation — When to Choose Which

Both approaches give identical optimal costs, but they have meaningfully different characteristics in practice. Bottom-up tabulation (what we built above) fills every sub-chain entry of the n×n table. For the standard problem, top-down memoization ends up computing the same set of sub-chains, because every (i, j) pair is reachable from the root query. The difference shows up in constrained variants (for example, when some splits are forbidden or pruned): there, top-down only touches the sub-problems it actually reaches, which can be significantly fewer.

For most competitive programming and interview contexts, bottom-up is preferred because it avoids recursion stack overhead and is easier to reason about iteratively. But in production systems where n can be large and the chain has known structure (e.g., always left-heavy or always right-heavy), top-down with memoization avoids unnecessary computation.

The top-down version stores results in the same long[n][n] table but initializes entries to -1 to distinguish 'not yet computed' from 'computed as zero'. The -1 sentinel is critical — miss it and you'll recompute sub-problems, destroying the O(n³) guarantee.

Both are O(n³) time and O(n²) space in the worst case. Neither is asymptotically better. The constant factors and cache behavior differ — tabulation tends to be more cache-friendly because it accesses memory in a predictable pattern.

MatrixChainMemo.java · JAVA
import java.util.Arrays;

public class MatrixChainMemo {

    private final int[] dims;
    private final long[][] memo;  // -1 means "not yet computed"
    private final int[][] split;  // stores optimal split points

    public MatrixChainMemo(int[] dims) {
        this.dims  = dims;
        int n      = dims.length - 1;
        this.memo  = new long[n][n];
        this.split = new int[n][n];

        // Initialize all entries to -1 (our "unvisited" sentinel)
        for (long[] row : memo) Arrays.fill(row, -1L);
    }

    /**
     * Recursively computes the minimum cost to multiply matrices i through j.
     * Results are cached in memo[][] to avoid redundant computation.
     */
    public long solve(int i, int j) {
        // Base case: a single matrix costs nothing
        if (i == j) return 0L;

        // Return cached result if already computed
        if (memo[i][j] != -1L) return memo[i][j];

        memo[i][j] = Long.MAX_VALUE; // sentinel before we find the minimum

        for (int k = i; k < j; k++) {
            long candidate = solve(i, k)
                           + solve(k + 1, j)
                           + (long) dims[i] * dims[k + 1] * dims[j + 1];

            if (candidate < memo[i][j]) {
                memo[i][j]  = candidate;
                split[i][j] = k; // track optimal split for reconstruction
            }
        }

        return memo[i][j];
    }

    public int[][] getSplitTable() { return split; }

    public static void main(String[] args) {
        // Same 6-matrix chain as the tabulation example — results must match
        int[] dims = {30, 35, 15, 5, 10, 20, 25};

        MatrixChainMemo solver = new MatrixChainMemo(dims);
        int n        = dims.length - 1;
        long minCost = solver.solve(0, n - 1);
        String parens = MatrixChainDP.buildParenthesization(solver.getSplitTable(), 0, n - 1);

        System.out.println("=== Matrix Chain Multiplication (Top-Down Memoization) ===");
        System.out.printf("Minimum scalar multiplications: %,d%n", minCost);
        System.out.println("Optimal parenthesization: " + parens);

        // Count how many unique sub-problems were actually evaluated
        long computed = 0;
        for (long[] row : solver.memo)
            for (long val : row)
                if (val != -1L) computed++;

        System.out.printf("Sub-problems computed: %d out of %d possible%n",
                          computed, n * n);
    }
}
▶ Output
=== Matrix Chain Multiplication (Top-Down Memoization) ===
Minimum scalar multiplications: 15,125
Optimal parenthesization: ((M1 x (M2 x M3)) x ((M4 x M5) x M6))
Sub-problems computed: 15 out of 36 possible
⚠️
Pro Tip: -1 Is Only a Safe Sentinel If Costs Are Non-Negative
The -1 sentinel works here because all multiplication costs are ≥ 0, so a legitimate computed answer can never be -1. If you ever adapt this pattern to a problem where negative costs are possible (like some weighted graph problems), use a separate boolean[][] visited array instead of relying on a sentinel value in the cost table.

Edge Cases, Production Gotchas, and the Real Interview Traps

The textbook version handles the happy path. Here's what breaks in the real world.

SINGLE MATRIX: n=1 means dims has exactly 2 elements. The loop for (int span = 2; span <= n; span++) never executes, and cost[0][0] stays 0. Correct — but make sure your calling code doesn't index cost[0][1] which doesn't exist.

DIMENSION MISMATCH: Your dims array must satisfy dims.length == numberOfMatrices + 1. If someone passes a malformed array (e.g., forgetting that matrix i's column count must equal matrix i+1's row count), the cost computation is meaningless. Validate this at the entry point.

NUMERIC SCALE: With matrices of dimension 1,000, a single join costs up to 1000³ = 10⁹ ops. The total cost of any parenthesization is a sum of n−1 joins, so even 20 such matrices stay around 2×10¹⁰, far below Long.MAX_VALUE (~9.2×10¹⁸). For practical dimensions (≤10,000), long is safe. int is not: a single join can already overflow it.
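A quick standalone illustration of how that overflow bites (not from the listings above): in all-int arithmetic the product silently wraps to a negative number, while casting the first operand promotes the whole expression to long.

```java
public class OverflowDemo {
    public static void main(String[] args) {
        int rows = 2000, shared = 2000, cols = 2000;

        // All-int arithmetic: 8,000,000,000 wraps around 2^32 and goes negative
        int wrapped = rows * shared * cols;

        // Casting the FIRST operand makes the whole chain long arithmetic
        long correct = (long) rows * shared * cols;

        System.out.println("int product:  " + wrapped);  // negative garbage, no exception
        System.out.println("long product: " + correct);  // 8000000000
    }
}
```

No exception is thrown on the wrapped value, which is exactly why this bug survives into production: the DP simply treats the negative "cost" as attractively cheap.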

RECONSTRUCTION BUG: The split table is only meaningful for i < j. Accessing split[i][i] returns 0 (Java default), which looks like a valid split point. Always gate reconstruction with the i == j base case FIRST, before reading split[i][j].

NOT JUST ABOUT SCALARS: Real matrix libraries (BLAS, NumPy, PyTorch) have non-trivial constant factors per operation. MCM gives you the asymptotically optimal order, but cache effects and SIMD vectorization mean the 'optimal' order in terms of op count isn't always fastest in wall-clock time. Profile before over-optimizing.

MatrixChainRobust.java · JAVA
public class MatrixChainRobust {

    /**
     * Validates the dimension array before running DP.
     * Catches common mistakes before they silently produce wrong answers.
     */
    public static void validateDimensions(int[] dims) {
        if (dims == null || dims.length < 2) {
            throw new IllegalArgumentException(
                "dims must have at least 2 elements (representing at least 1 matrix). " +
                "Got: " + (dims == null ? "null" : dims.length));
        }
        for (int i = 0; i < dims.length; i++) {
            if (dims[i] <= 0) {
                throw new IllegalArgumentException(
                    "All dimensions must be positive. dims[" + i + "] = " + dims[i]);
            }
        }
        // Note: We can't validate inner dimension compatibility here because the
        // dims array implicitly encodes that dims[i] == column count of matrix i-1
        // == row count of matrix i. The format itself guarantees compatibility.
    }

    public static long minCostSafe(int[] dims) {
        validateDimensions(dims);

        int n = dims.length - 1;

        // Edge case: single matrix needs zero multiplications
        if (n == 1) return 0L;

        long[][] cost = new long[n][n];

        for (int span = 2; span <= n; span++) {
            for (int i = 0; i <= n - span; i++) {
                int j = i + span - 1;
                cost[i][j] = Long.MAX_VALUE;

                for (int k = i; k < j; k++) {
                    // Guard against overflow: check before adding
                    long joinCost = (long) dims[i] * dims[k + 1] * dims[j + 1];

                    // If either sub-problem somehow hit MAX_VALUE, skip (defensive)
                    if (cost[i][k] == Long.MAX_VALUE || cost[k + 1][j] == Long.MAX_VALUE) {
                        continue;
                    }

                    long total = cost[i][k] + cost[k + 1][j] + joinCost;
                    if (total < cost[i][j]) {
                        cost[i][j] = total;
                    }
                }
            }
        }
        return cost[0][n - 1];
    }

    public static void main(String[] args) {
        System.out.println("=== Robust MCM with Edge Case Handling ===");

        // Edge case 1: Single matrix
        int[] singleMatrix = {100, 200};
        System.out.println("Single matrix cost: " + minCostSafe(singleMatrix)); // 0

        // Edge case 2: Two matrices (only one way to multiply)
        int[] twoMatrices = {10, 20, 30};
        System.out.println("Two matrices cost: " + minCostSafe(twoMatrices)); // 6000

        // Edge case 3: Very wide vs very tall — order matters enormously
        int[] extremeDims = {1, 1000, 1, 1000}; // M1(1x1000), M2(1000x1), M3(1x1000)
        System.out.println("Extreme dims cost: " + minCostSafe(extremeDims)); // 2000

        // Edge case 4: Bad input — caught before DP runs
        try {
            minCostSafe(new int[]{10, -5, 20});
        } catch (IllegalArgumentException e) {
            System.out.println("Caught: " + e.getMessage());
        }
    }
}
▶ Output
=== Robust MCM with Edge Case Handling ===
Single matrix cost: 0
Two matrices cost: 6000
Extreme dims cost: 2000
Caught: All dimensions must be positive. dims[1] = -5
⚠️
Watch Out: The Split Table Has Undefined Cells
split[i][i] is never written by the DP — Java initializes it to 0, which happens to be a valid matrix index. If your reconstruction code reads split[i][i] instead of returning early on the i == j base case, you'll get infinite recursion or a silently wrong parenthesization. Always check i == j BEFORE accessing split[i][j].
Aspect                | Bottom-Up Tabulation                 | Top-Down Memoization
Time Complexity       | O(n³) — always fills all cells       | O(n³) worst case, often less
Space Complexity      | O(n²) for the DP table               | O(n²) table + O(n) call stack
Sub-problems computed | All n² cells, even unused ones       | Only reachable sub-problems
Stack overflow risk   | None — purely iterative              | Yes, for very large n (n > ~5000)
Cache friendliness    | High — sequential row access         | Lower — unpredictable access order
Code clarity          | Trickier loop ordering to get right  | Closer to the recursive definition
Best for              | Interviews, dense chains, all n      | Sparse chains, structured inputs
Reconstruction        | Read split[][] after loop            | Read split[][] after solve() returns

🎯 Key Takeaways

  • The cost of multiplying (a×b) by (b×c) is a×b×c — that single formula IS the entire problem; DP just finds which order of applying it is cheapest.
  • Fill the DP table by increasing chain span (length 2, then 3, … then n) — not by row. The dependency direction makes this non-negotiable.
  • Always use long for cost accumulation — int overflow in MCM is silent and produces wrong answers that look plausible, making it one of the nastiest bugs to debug.
  • The split table is as important as the cost table: minimum cost tells you HOW cheap; split tells you HOW to achieve it. In production, you need both.

⚠ Common Mistakes to Avoid

  • Mistake 1: Iterating the outer DP loop over i instead of span (chain length) — Symptom: cost[i][j] reads from cost[k+1][j] cells that haven't been filled yet, producing incorrect (often zero or MAX_VALUE) costs silently — Fix: always make the outermost loop for (int span = 2; span <= n; span++) so all sub-problems of length span-1 are solved before any of length span.
  • Mistake 2: Using int instead of long for accumulated costs — Symptom: for matrices with dimensions in the hundreds, costs silently overflow int and wrap to negative values, making those split points look artificially cheap and giving a completely wrong answer with no exception thrown — Fix: declare cost[][] as long[][] and cast immediately before multiplication: (long) dims[i] * dims[k+1] * dims[j+1].
  • Mistake 3: Confusing the number of matrices (n) with the length of the dims array (n+1) — Symptom: an ArrayIndexOutOfBoundsException on dims[j+1] when j reaches n-1, or missing the last matrix entirely if you loop to dims.length-2 instead of dims.length-1 — Fix: always derive n = dims.length - 1 at the start and treat n as the matrix count; the j+1 index in the cost formula is safe because j goes up to n-1, so j+1 = n which is a valid dims index.

Interview Questions on This Topic

  • Q: Why does Matrix Chain Multiplication have optimal substructure, and how is that different from a greedy approach? Can you think of a greedy heuristic that fails here?
  • Q: Walk me through the recurrence relation — specifically, why does the join cost use dims[i] × dims[k+1] × dims[j+1] and not dims[i] × dims[j]?
  • Q: If I give you a chain of matrices that I know are always multiplied left-to-right in production, does MCM still help? What's the actual complexity saving compared to left-to-right evaluation?

Frequently Asked Questions

What is the time complexity of Matrix Chain Multiplication using dynamic programming?

O(n³) time and O(n²) space, where n is the number of matrices. There are O(n²) sub-problems (all pairs i,j) and each takes O(n) to solve by trying every split point k. This is a massive improvement over brute force, which grows with the Catalan numbers — roughly 4ⁿ/n^(3/2).

Does Matrix Chain Multiplication change the result of the multiplication, or only the speed?

Only the speed. Matrix multiplication is associative, meaning (AB)C = A(BC) always produces the same matrix. MCM exploits this to reorder the computation without changing the mathematical result. It finds the parenthesization that minimizes scalar multiplications.

Why can't we use a greedy approach — like always multiplying the two cheapest adjacent matrices first?

Because locally cheap choices can create expensive intermediate matrices that blow up later costs. For example, merging two tiny adjacent matrices first might produce a massive intermediate that then dominates all remaining multiplications. DP avoids this by considering ALL possible splits globally, not just the cheapest local one.
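To make that failure concrete, here is a small standalone sketch (with illustrative dimensions, not taken from the listings above) that pits the "cheapest adjacent pair first" greedy heuristic against the DP optimum:

```java
import java.util.ArrayList;
import java.util.List;

public class GreedyVsOptimal {

    /** Greedy: repeatedly multiply the adjacent pair with the smallest immediate cost. */
    public static long greedyCost(int[] dimsIn) {
        List<Integer> dims = new ArrayList<>();
        for (int d : dimsIn) dims.add(d);
        long total = 0;
        while (dims.size() > 2) { // more than one matrix remains
            int bestIdx = 0;
            long bestCost = Long.MAX_VALUE;
            for (int i = 0; i + 2 < dims.size(); i++) {
                long c = (long) dims.get(i) * dims.get(i + 1) * dims.get(i + 2);
                if (c < bestCost) { bestCost = c; bestIdx = i; }
            }
            total += bestCost;
            dims.remove(bestIdx + 1); // merge the pair: the shared dimension disappears
        }
        return total;
    }

    /** Standard O(n^3) bottom-up DP, for comparison. */
    public static long optimalCost(int[] dims) {
        int n = dims.length - 1;
        long[][] cost = new long[n][n];
        for (int span = 2; span <= n; span++) {
            for (int i = 0; i <= n - span; i++) {
                int j = i + span - 1;
                cost[i][j] = Long.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    long c = cost[i][k] + cost[k + 1][j]
                           + (long) dims[i] * dims[k + 1] * dims[j + 1];
                    cost[i][j] = Math.min(cost[i][j], c);
                }
            }
        }
        return n == 1 ? 0L : cost[0][n - 1];
    }

    public static void main(String[] args) {
        int[] dims = {40, 20, 30, 10, 30}; // 4 matrices
        System.out.println("Greedy:  " + greedyCost(dims));   // 36000
        System.out.println("Optimal: " + optimalCost(dims));  // 26000
    }
}
```

On this chain the greedy heuristic merges the locally cheapest pairs (6,000 + 6,000) but leaves an expensive 40×20 by 20×30 product for last, ending at 36,000 ops versus the DP optimum of 26,000.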

TheCodeForge Editorial Team

Written and reviewed by senior developers with real-world experience across enterprise, startup and open-source projects. Every article on TheCodeForge is written to be clear, accurate and genuinely useful — not just SEO filler.

Forged with 🔥 at TheCodeForge.io — Where Developers Are Forged