Solving Problems With Dynamic Programming

9th May 2016

Dynamic programming is a really useful general technique for solving problems that involves breaking problems down into smaller overlapping sub-problems, storing the results computed from those sub-problems, and reusing those results on larger chunks of the problem. Dynamic programming solutions are pretty much always more efficient than naive brute-force solutions. It's particularly effective on problems that contain optimal substructure.

Dynamic programming is related to a number of other fundamental concepts in computer science in interesting ways. Recursion, for example, is similar to (but not identical to) dynamic programming. The key difference is that in a naive recursive solution, answers to sub-problems may be computed many times. A recursive solution that caches answers to sub-problems which were already computed is called memoization, which is essentially the top-down counterpart to dynamic programming's bottom-up approach (we'll see a memoized version of the Fibonacci function below). Another variation is when the sub-problems don't actually overlap at all, in which case the technique is known as divide and conquer. Finally, dynamic programming is tied to the concept of mathematical induction and can be thought of as a specific application of inductive reasoning in practice.

While the core ideas behind dynamic programming are actually pretty simple, it turns out that it's fairly challenging to use on non-trivial problems because it's often not obvious how to frame a difficult problem in terms of overlapping sub-problems. This is where experience and practice come in handy, which is the idea behind this blog post. We'll build both naive and "intelligent" solutions to several well-known problems and see how the problems are decomposed to use dynamic programming solutions. The code is written in basic Python, using NumPy only to generate test data.

Fibonacci Numbers

First we'll look at the problem of computing numbers in the Fibonacci sequence. The problem definition is very simple - each number in the sequence is the sum of the two previous numbers in the sequence. Or, more formally:

\(F_n = F_{n-1} + F_{n-2}\), with \(F_0 = 0\) and \(F_1 = 1\) as the seed values.

Our solution will be responsible for calculating each of the Fibonacci numbers up to some defined limit. We'll first implement a naive solution that re-calculates each number in the sequence from scratch.

def fib(n):
    # base cases: the seed values F0 = 0 and F1 = 1
    if n == 0:
        return 0
    if n == 1:
        return 1
    
    # every other number is the sum of the two previous numbers
    return fib(n - 1) + fib(n - 2)

def all_fib(n):
    # compute each of the first n Fibonacci numbers from scratch
    fibs = []
    for i in range(n):
        fibs.append(fib(i))
    
    return fibs

Let's try it out on a pretty small number first.

%time print(all_fib(10))

[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
Wall time: 0 ns

Okay, probably too trivial. Let's try a bit bigger...

%time print(all_fib(20))

[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181]
Wall time: 5 ms

The runtime was at least measurable now, but still pretty quick. Let's try one more time...

%time print(all_fib(40))

[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418, 317811, 514229, 832040, 1346269, 2178309, 3524578, 5702887, 9227465, 14930352, 24157817, 39088169, 63245986]
Wall time: 1min 9s

That escalated quickly! Clearly this is a pretty bad solution. Let's see what it looks like when applying dynamic programming.

def all_fib_dp(n):
    fibs = []
    for i in range(n):
        if i < 2:
            # seed values F0 = 0 and F1 = 1
            fibs.append(i)
        else:
            # reuse the two previously saved results
            fibs.append(fibs[i - 2] + fibs[i - 1])
    
    return fibs

This time we're saving the result at each iteration and computing new numbers as a sum of the previously saved results. Let's see what this does to the performance of the function.

%time print(all_fib_dp(40))

[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418, 317811, 514229, 832040, 1346269, 2178309, 3524578, 5702887, 9227465, 14930352, 24157817, 39088169, 63245986]
Wall time: 0 ns

By not computing the full recursive tree on each iteration, we've essentially reduced the running time for the first 40 numbers from ~70 seconds to virtually instant. This also happens to be a good example of the danger of naive recursive functions. Our new Fibonacci number function can compute additional values in linear time vs. exponential time for the first version.

%time print(all_fib_dp(100))

[0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418, 317811, 514229, 832040, 1346269, 2178309, 3524578, 5702887, 9227465, 14930352, 24157817, 39088169, 63245986, 102334155, 165580141, 267914296, 433494437, 701408733, 1134903170, 1836311903, 2971215073, 4807526976, 7778742049, 12586269025, 20365011074, 32951280099, 53316291173, 86267571272, 139583862445, 225851433717, 365435296162, 591286729879, 956722026041, 1548008755920, 2504730781961, 4052739537881, 6557470319842, 10610209857723, 17167680177565, 27777890035288, 44945570212853, 72723460248141, 117669030460994, 190392490709135, 308061521170129, 498454011879264, 806515533049393, 1304969544928657, 2111485077978050, 3416454622906707, 5527939700884757, 8944394323791464, 14472334024676221, 23416728348467685, 37889062373143906, 61305790721611591, 99194853094755497, 160500643816367088, 259695496911122585, 420196140727489673, 679891637638612258, 1100087778366101931, 1779979416004714189, 2880067194370816120, 4660046610375530309, 7540113804746346429, 12200160415121876738, 19740274219868223167, 31940434634990099905, 51680708854858323072, 83621143489848422977, 135301852344706746049, 218922995834555169026]
Wall time: 0 ns
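
As an aside, the memoization approach mentioned at the start gets us the same linear-time behavior while keeping the recursive structure. Here's a minimal sketch using functools.lru_cache to do the caching (fib_memo and all_fib_memo are names made up for this example):

from functools import lru_cache

@lru_cache(maxsize=None)
def fib_memo(n):
    # same recursion as fib(), but each sub-problem is computed only once
    if n < 2:
        return n
    return fib_memo(n - 1) + fib_memo(n - 2)

def all_fib_memo(n):
    return [fib_memo(i) for i in range(n)]

print(all_fib_memo(10))  # -> [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]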

Longest Increasing Subsequence

The Fibonacci problem is a good starter example but doesn't really capture the challenge of representing problems in terms of optimal sub-problems because for Fibonacci numbers the answer is pretty obvious. Let's move up one step in difficulty to a problem known as the longest increasing subsequence problem. The objective is to find the longest subsequence of a given sequence such that all elements in the subsequence are sorted in increasing order. Note that the elements do not need to be contiguous; that is, they are not required to appear next to each other. For example, in the sequence [10, 22, 9, 33, 21, 50, 41, 60, 80] the longest increasing subsequence (LIS) is [10, 22, 33, 50, 60, 80].

It turns out that it's fairly difficult to do a "brute-force" solution to this problem. The dynamic programming solution is much more concise and a natural fit for the problem definition, so we'll skip creating an unnecessarily complicated naive solution and jump straight to the DP solution.

def find_lis(seq):
    n = len(seq)
    max_length = 1
    best_seq_end = 0
    
    # prev[i] is the index of the previous element in the lis ending at i
    prev = [-1 for i in range(n)]
    
    # length[i] is the length of the lis ending at position i
    # (every element is an increasing subsequence of length 1 by itself)
    length = [1 for i in range(n)]
    
    for i in range(1, n):
        # start from index i-1 and work back to 0
        for j in range(i - 1, -1, -1):
            if (length[j] + 1) > length[i] and seq[j] < seq[i]:
                # there's a number before position i that extends the lis at i
                length[i] = length[j] + 1
                prev[i] = j
        
        if length[i] > max_length:
            max_length = length[i]
            best_seq_end = i
    
    # recover the subsequence by following the chain of indices backwards
    lis = []
    element = best_seq_end
    while element != -1:
        lis.append(seq[element])
        element = prev[element]
    
    return lis[::-1]

The intuition here is that for a given index \(i\), we can compute the length of the longest increasing subsequence ending at \(i\), \(length(i)\), by looking at all indices \(j < i\); if \(length(j) + 1 > length(i)\) and \(seq[j] < seq[i]\) (meaning there's a number at position \(j\) whose subsequence can be extended by \(seq[i]\) into something longer than the best recorded at \(i\)), then we set \(length(i) = length(j) + 1\). It's a bit confusing at first glance but step through it carefully and convince yourself that this solution finds the optimal subsequence. The "prev" list holds the index of each element's predecessor in the subsequence so that the actual values can be rebuilt at the end.
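
To see the bookkeeping in action, consider a tiny hypothetical input:

# length ends up as [1, 1, 2] and prev as [-1, -1, 1], so the chain
# is recovered from index 2 back through index 1
print(find_lis([3, 1, 2]))  # -> [1, 2]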

Let's generate some test data and try it out.

import numpy as np
seq_small = list(np.random.randint(0, 20, 20))
seq_small

[16, 10, 17, 18, 9, 0, 2, 19, 4, 3, 1, 14, 12, 6, 2, 4, 11, 5, 19, 4]

Now we can run a quick test to see if it works on a small sequence.

%time print(find_lis(seq_small))

[0, 1, 2, 4, 5, 19]
Wall time: 0 ns

Just based on the eye test the output looks correct. Let's see how well it performs on much larger sequences.

seq = list(np.random.randint(0, 10000, 10000))
%time print(find_lis(seq))

[29, 94, 125, 159, 262, 271, 274, 345, 375, 421, 524, 536, 668, 689, 694, 755, 763, 774, 788, 854, 916, 1018, 1022, 1098, 1136, 1154, 1172, 1237, 1325, 1361, 1400, 1401, 1406, 1450, 1498, 1633, 1693, 1745, 1765, 1793, 1835, 1949, 1997, 2069, 2072, 2096, 2157, 2336, 2345, 2468, 2519, 2529, 2624, 2630, 2924, 3103, 3291, 3321, 3380, 3546, 3635, 3657, 3668, 3703, 3775, 3836, 3850, 3961, 4002, 4004, 4039, 4060, 4128, 4361, 4377, 4424, 4432, 4460, 4465, 4493, 4540, 4595, 4693, 4732, 4735, 4766, 4831, 4850, 4873, 4908, 4940, 4969, 5013, 5073, 5087, 5139, 5144, 5271, 5280, 5299, 5300, 5355, 5393, 5430, 5536, 5538, 5559, 5565, 5822, 5891, 5895, 5906, 6157, 6199, 6286, 6369, 6431, 6450, 6510, 6533, 6577, 6585, 6683, 6701, 6740, 6745, 6829, 6853, 6863, 6872, 6884, 6923, 6925, 7009, 7019, 7028, 7040, 7170, 7235, 7304, 7356, 7377, 7416, 7490, 7495, 7662, 7676, 7703, 7808, 7925, 7971, 8036, 8073, 8282, 8295, 8332, 8342, 8360, 8429, 8454, 8499, 8557, 8585, 8639, 8649, 8725, 8759, 8831, 8860, 8899, 8969, 9046, 9146, 9161, 9245, 9270, 9374, 9451, 9465, 9515, 9522, 9525, 9527, 9664, 9770, 9781, 9787, 9914, 9993]
Wall time: 4.94 s

So it's still pretty fast, but the difference is definitely noticeable. At 10,000 integers in the sequence our algorithm already takes several seconds to complete. In fact, even though this solution uses dynamic programming, its runtime is still \(O(n^2)\). The lesson here is that dynamic programming doesn't always result in lightning-fast solutions. There are also different ways to apply DP to the same problem. In fact there's a solution to this problem that uses binary search and runs in \(O(n \log n)\) time, significantly better than the solution we just came up with.
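
Just to illustrate the idea, here's a minimal sketch of that \(O(n \log n)\) approach using Python's bisect module (lis_length is a made-up name, and this version only reports the length of the subsequence; recovering the elements themselves takes some extra bookkeeping):

from bisect import bisect_left

def lis_length(seq):
    # tails[k] holds the smallest value that ends any increasing
    # subsequence of length k+1 seen so far; tails is always sorted,
    # so each new element can be placed with a binary search
    tails = []
    for x in seq:
        i = bisect_left(tails, x)
        if i == len(tails):
            tails.append(x)   # x extends the longest subsequence so far
        else:
            tails[i] = x      # x gives a smaller tail for length i+1
    return len(tails)

print(lis_length([10, 22, 9, 33, 21, 50, 41, 60, 80]))  # -> 6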

Knapsack Problem

The knapsack problem is another classic dynamic programming exercise. The problem is very old, comes in many variations, and there are actually multiple ways to tackle it aside from dynamic programming. Still, it's a common example for DP exercises.

The problem at its core is one of combinatorial optimization. Given a set of items, each with a weight and a value, determine the collection of items that results in the highest possible value while not exceeding some limit on the total weight. The variation we'll look at is commonly referred to as the 0-1 knapsack problem, which restricts the number of copies of each kind of item to 0 or 1. More formally, given a set of \(n\) items each with weight \(w_i\) and value \(v_i\) along with a maximum total weight \(W\), our objective is:

\(\max \sum_{i=1}^{n} v_i x_i\), subject to \(\sum_{i=1}^{n} w_i x_i \leq W\) and \(x_i \in \{0, 1\}\)

Let's see what the implementation looks like then discuss why it works.

def knapsack(W, w, v):
    # create an (n+1) x (W+1) solution matrix to store the sub-problem results
    n = len(v)
    S = [[0 for x in range(W + 1)] for k in range(n + 1)]
    
    for x in range(1, W + 1):
        for k in range(1, n + 1):
            # using this notation k is the number of items considered and x is the max weight of the
            # solution, so the initial assumption is that the optimal solution considering k items at
            # weight x is the same as the optimal solution considering k-1 items at the same max weight
            S[k][x] = S[k-1][x]
            
            # if the current item fits within the max weight and the optimal solution including this item is 
            # better than the current optimum, the new optimum is the one resulting from including the current item
            if w[k-1] <= x and S[k-1][x - w[k-1]] + v[k-1] > S[k][x]:
                S[k][x] = S[k-1][x - w[k-1]] + v[k-1]
    
    return S
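
In equation form, the update performed inside the loop is the standard 0-1 knapsack recurrence (using 1-indexed items to match the matrix rows):

\(S[k][x] = \max(S[k-1][x], \; S[k-1][x - w_k] + v_k)\) if \(w_k \leq x\), otherwise \(S[k][x] = S[k-1][x]\)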

The intuition behind this algorithm is that once you've solved for the optimal combination of items at some weight \(x < W\) and with some number of items \(k < n\), then it's easy to solve the problem with one more item or one higher max weight because you can just check whether the solution obtained by incorporating that item is better than the best solution you've already found. So how do you get the initial solution? Keep going down the rabbit hole until you reach 0 (in which case the answer is 0). At first glance it's very hard to grasp, but that's part of the magic of dynamic programming. Let's run an example to see what it looks like. We'll start with some randomly-generated weights and values.

w = list(np.random.randint(0, 10, 5))
v = list(np.random.randint(0, 100, 5))
w, v

([3, 9, 3, 6, 5], [40, 45, 72, 77, 16])

Now we can run the algorithm with a constraint that the weights of the items can't add up to more than 15.

knapsack(15, w, v)

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [0, 0, 0, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40],
 [0, 0, 0, 40, 40, 40, 40, 40, 40, 45, 45, 45, 85, 85, 85, 85],
 [0, 0, 0, 72, 72, 72, 112, 112, 112, 112, 112, 112, 117, 117, 117, 157],
 [0, 0, 0, 72, 72, 72, 112, 112, 112, 149, 149, 149, 189, 189, 189, 189],
 [0, 0, 0, 72, 72, 72, 112, 112, 112, 149, 149, 149, 189, 189, 189, 189]]

The output here is the array of optimal values for a given max weight (think of it as the column index) and number of items considered (the row index). Notice how the output follows what looks sort of like a wavefront pattern. This seems to be a recurring phenomenon with dynamic programming solutions. The value in the lower right corner (189 here) is the max value that we were looking for under the given constraints and is the answer to the problem.
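
If we only want the final answer rather than the whole table, we can read off that lower-right entry directly:

S = knapsack(15, w, v)
print(S[-1][-1])  # -> 189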

That concludes our introduction to dynamic programming! Using this technique in the real world definitely requires a lot of practice; most applications of dynamic programming are not very obvious and take some skill to discover. Personally it doesn't come naturally to me at all and even learning these relatively simple examples took quite a bit of thought. It might seem like these sorts of problems don't come up all that often in practice, and there's probably some truth to that. However I've found that simply knowing about dynamic programming and how it fits into a more general problem-solving framework has made me a better engineer, and that in and of itself makes it worth the time investment to understand.

Follow me on twitter to get new post updates.


