Calculating binomial coefficients


In math, the binomial coefficient is defined as

$$\binom{n}{k} = \frac{n!}{k!\,(n-k)!}$$
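A quick worked example with small numbers, n = 5 and k = 2:

$$\binom{5}{2} = \frac{5!}{2!\,3!} = \frac{120}{2 \times 6} = 10$$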

Find an efficient way to compute this.

Ask yourself: Do I understand the problem?

Setting up a test harness

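As usual, our first step is to prepare a testbed. Our code strategy for the test will be this: run every method on the same range of inputs, time each one, and check that all methods produce the same result.

Therefore, our test code looks like this: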

# -*- coding: latin-1 -*-

def selftest():
    import time
    
    methods = { 'using brute force' : bc_brute_force,
                'using dynamic programming' : bc_dynamic_programming,
                # ... plus each further bc_* function as it is introduced below
              }
                
    expected_result = None
    for method_name in methods:
        method = methods[method_name]
        
        t0 = time.time()
        result = 0
        for n in range(1000,1010):
            for k in range(5,10):
                result ^= method(n, k)
        print "%30s: took %.2f for result %r" % (method_name, time.time() - t0, result, )
        if expected_result is None:
            expected_result = result
        else:
            assert result == expected_result
            
if __name__ == "__main__":
    selftest()

This code doesn't compare individual results; it XORs all results of a method together and compares those combined values. But that is just a small sideshow to the main algorithm.
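If you would rather have the harness point at *which* method is wrong, instead of only asserting on a combined XOR, you can compare the full result lists. The following is a minimal sketch, not part of the original testbed; `find_disagreements` is a hypothetical helper name, and it assumes `methods` is the same name-to-function dict used in `selftest()`:

```python
def find_disagreements(methods, inputs):
    """Return the names of methods whose results differ from a reference method.

    methods: non-empty dict mapping a display name to a function f(n, k)
    inputs:  list of (n, k) pairs that every method is run on
    The reference is simply the alphabetically first name, so if *that*
    method is the buggy one, every other method is reported instead.
    """
    names = sorted(methods)
    reference = [methods[names[0]](n, k) for n, k in inputs]
    disagreeing = []
    for name in names[1:]:
        results = [methods[name](n, k) for n, k in inputs]
        # comparing whole lists means two wrong answers cannot cancel out,
        # which they could in an XOR checksum
        if results != reference:
            disagreeing.append(name)
    return disagreeing
```

In `selftest()` this could be called with `[(n, k) for n in range(1000, 1010) for k in range(5, 10)]`, the same inputs the XOR loop uses.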

Brute-Force approach

Let's try brute force first. The mathematical formula at the top is pretty straightforward: calculate three factorials, then evaluate the quotient.

def fac(n):
    result = 1
    for i in xrange(2, n+1):
        result *= i
    return result
    
def bc_brute_force(n, k):
    return fac(n) / (fac(k) * fac(n-k))
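The code in this article is Python 2, where `/` on two integers is integer division. If you follow along in Python 3, `/` is true division and returns a float. A minimal Python 3 port of the two functions above might look like this (the `_py3` names are mine, chosen so they don't clash with the article's functions):

```python
def fac_py3(n):
    result = 1
    for i in range(2, n + 1):   # range instead of Python 2's xrange
        result *= i
    return result

def bc_brute_force_py3(n, k):
    # "//" is floor division and stays an exact integer; "/" would return a
    # float here (inexact beyond 2**53, OverflowError if the result is too big)
    return fac_py3(n) // (fac_py3(k) * fac_py3(n - k))
```

The division is always exact, because the binomial coefficient is an integer, so `//` loses nothing.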

The brute-force version raises the question: what is the fastest factorial function here? Let's briefly look at three alternatives:

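import operator
import time

def fac1(n):
    result = 1
    for i in range(2, n+1):
        result *= i
    return result
    
def fac2(n):
    # the initial value 1 keeps reduce() from raising TypeError on an empty range (n < 1)
    return reduce(operator.mul, xrange(1, n+1), 1)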
    
def fac3(n):
    result = 1
    i = n
    while i >= 2:
        result *= i
        i -= 1
    return result
    
# quick testbed: evaluating the best performance of the three factorial functions
for func in (fac1, fac2, fac3, ):
    t0 = time.time()
    func(100000)
    print func, time.time() - t0

Quick: guess which one is the fastest? I for one was surprised:

<function fac1 at 0x01F2D230> 11.4319999218
<function fac2 at 0x0237B9F0> 11.4539999962
<function fac3 at 0x0237B8F0>  9.6970000267

It turns out that the first two functions, while pythonic and all, are slower. My guess is that they allocate more memory than is good for them (in Python 2, fac1's range() builds the whole list up front). So fac3 is what we're going to use for a slightly optimized brute force:

def fac(n):
    result = 1
    i = n
    while i >= 2:
        result *= i
        i -= 1
    return result
    
def bc_brute_force(n, k):
    return fac(n) / (fac(k) * fac(n-k))

Some notes on this implementation:

Using Dynamic Programming

Well, if you look into textbooks, our factorial function already counts as dynamic programming. But we can improve on that by noticing that the original definition asks us to calculate three things: n!, k! and (n-k)!.

That means that as we "visit" each number to calculate n!, we might as well pick up k! and (n-k)! along the way, turning three loops into one.

def bc_dynamic_programming(n, k):
    n_fak = 1
    k_fak = 1          # 0! == 1, so k == 0 works without special-casing
    n_minus_k_fak = 1  # likewise for k == n, where n - k == 0
    n_minus_k = n - k
    
    for x in range(1, n+1):
        n_fak *= x
        if x == k:
            k_fak = n_fak
        
        if x == n_minus_k:
            n_minus_k_fak = n_fak
            
    return n_fak / (k_fak * n_minus_k_fak)

Next stop: algebra

Take a look at the definition of the binomial coefficient again, this time with the factorials written out:

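$$\binom{n}{k} = \frac{n!}{k!\,(n-k)!}
= \frac{1 \times 2 \times \cdots \times n}{(1 \times 2 \times \cdots \times k)\,(1 \times 2 \times \cdots \times (n-k))}
= \frac{(1 \times 2 \times \cdots \times k)\,((k+1) \times (k+2) \times \cdots \times n)}{(1 \times 2 \times \cdots \times k)\,(1 \times 2 \times \cdots \times (n-k))}
= \frac{(k+1) \times (k+2) \times \cdots \times n}{1 \times 2 \times \cdots \times (n-k)}$$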

You can see that this saves a number of multiplications: the k! cancels out completely. Now, there is an alternative way to optimize the expression:

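$$\binom{n}{k} = \frac{n!}{k!\,(n-k)!}
= \frac{1 \times 2 \times \cdots \times n}{(1 \times 2 \times \cdots \times k)\,(1 \times 2 \times \cdots \times (n-k))}
= \frac{(1 \times 2 \times \cdots \times (n-k))\,((n-k+1) \times (n-k+2) \times \cdots \times n)}{(1 \times 2 \times \cdots \times k)\,(1 \times 2 \times \cdots \times (n-k))}
= \frac{(n-k+1) \times (n-k+2) \times \cdots \times n}{1 \times 2 \times \cdots \times k}$$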

Think about the question: when does it make sense to use either representation? If k > n/2, then the term 1 × 2 × ... × k is longer than the term 1 × 2 × ... × (n-k), so cancelling it, i.e. using the first representation, leaves fewer factors to multiply. If k < n/2, the second representation wins, and for k = n/2 both are the same.
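A small worked example makes the trade-off concrete. For n = 10 and k = 7 (so k > n/2), the first representation needs three factors above and below the line, while the second needs seven above and seven below:

$$\binom{10}{7} = \frac{8 \times 9 \times 10}{1 \times 2 \times 3} = \frac{720}{6} = 120
\qquad\text{versus}\qquad
\binom{10}{7} = \frac{4 \times 5 \times \cdots \times 10}{1 \times 2 \times \cdots \times 7} = \frac{604800}{5040} = 120$$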

Armed with that, let's try to write code for it:

def bc_optim(n, k):
    
    if k > n/2:
        upper_range = range(k+1,n+1)
        lower_range = range(2, (n-k+1))
        
    else:
        upper_range = range(n-k+1,n+1)
        lower_range = range(2, k+1)
        
    u = 1
    for dividend in upper_range:
        u *= dividend

    l = 1
    for divisor in lower_range:
        l *= divisor
        
    return u/l
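The two branches above are really one rule in disguise: by the symmetry C(n, k) = C(n, n-k), you always cancel the larger of k! and (n-k)!. Here is a compact sketch of the same arithmetic as bc_optim, expressed through min(k, n-k) and using `//` so it also runs under Python 3. The name `bc_optim_compact` is hypothetical, and it assumes 0 <= k <= n:

```python
def bc_optim_compact(n, k):
    # Same idea as bc_optim: cancel the larger of k! and (n-k)!, so only
    # r = min(k, n - k) factors remain above and below the line.
    r = min(k, n - k)
    numerator = 1
    for factor in range(n - r + 1, n + 1):   # (n-r+1) * ... * n
        numerator *= factor
    denominator = 1
    for factor in range(2, r + 1):           # r!
        denominator *= factor
    return numerator // denominator          # exact: the quotient is an integer
```

For k > n/2 this yields r = n - k and reproduces the first representation; otherwise r = k and it reproduces the second.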

So, was all that math worth it? Let's look at our testbed and compare the results:

             using brute force: took 1.56 for result 4423137553666283105672019847538510671575080839179707002338085191579938072287668281631536191437946606118381203890479389237995287646214963339276827695049847636402807681032197306196418310999689640422676267452377523863414743729520994673266549091320284038814420240634986799913936602081779632146004368166601L
     using dynamic programming: took 0.92 for result 4423137553666283105672019847538510671575080839179707002338085191579938072287668281631536191437946606118381203890479389237995287646214963339276827695049847636402807681032197306196418310999689640422676267452377523863414743729520994673266549091320284038814420240634986799913936602081779632146004368166601L
          using optimized code: took 0.43 for result 4423137553666283105672019847538510671575080839179707002338085191579938072287668281631536191437946606118381203890479389237995287646214963339276827695049847636402807681032197306196418310999689640422676267452377523863414743729520994673266549091320284038814420240634986799913936602081779632146004368166601L

So the last function is indeed a significant improvement over the first (0.43 s versus 1.56 s, roughly 3.6 times faster), even though by the rules of Big-O nothing has really changed: it still performs O(n) multiplications, but the constant factors are much better now.

Caring about overflow: AKA Pascal's Triangle

Now, problem 12.14 in the very fine Elements of Programming Interviews asks us to design an efficient algorithm for binomial coefficients (we did that), but one that doesn't overflow.

Well, duh: Python doesn't overflow. Its integers grow into arbitrary-precision bignums and cannot overflow like that. But there are poor people out there forced to write Java code, or C/C++ code (well, C/C++ has its merits), where integers are limited to 32 or 64 bits. So how can we help those poor chaps and gals?
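To see the problem without leaving Python, you can simulate unsigned 64-bit arithmetic (as with C's `uint64_t`, which wraps around modulo 2**64) by masking every intermediate product. This is a sketch for illustration only; `MASK64`, `fac_u64` and `bc_brute_force_u64` are hypothetical names:

```python
MASK64 = (1 << 64) - 1   # keeps the low 64 bits, like C's uint64_t

def fac_u64(n):
    # factorial with every intermediate product reduced modulo 2**64,
    # mimicking unsigned 64-bit wraparound
    result = 1
    for i in range(2, n + 1):
        result = (result * i) & MASK64
    return result

def bc_brute_force_u64(n, k):
    # brute force on simulated 64-bit integers; only meant for small demos
    # (for larger n the truncated denominator can even become 0)
    denominator = (fac_u64(k) * fac_u64(n - k)) & MASK64
    return fac_u64(n) // denominator

print(fac_u64(20) == 2432902008176640000)   # True: 20! still fits in 64 bits
print(bc_brute_force_u64(21, 10))            # garbage: the exact answer is 352716
```

The result C(21, 10) = 352716 is tiny, yet 21! no longer fits into 64 bits, so the brute-force formula breaks long before the answer itself gets large.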

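The reference solution in the book works basically like this: it builds up Pascal's triangle using the recurrence $\binom{n}{k} = \binom{n-1}{k-1} + \binom{n-1}{k}$, so every intermediate value is itself a binomial coefficient and only additions are needed.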

OK, let's calculate the coefficient using Pascal's triangle. The code is actually very simple, much simpler than our previous suggestions:

def bc_pascal_triangle(n, k):
    A = [1] * (n+1)
    for i in range(1,n):
        for j in range(i,0,-1):
            A[j] += A[j-1]
    return A[k]

Whoa, too fast. Let's comment the code to see what really happens.

def bc_pascal_triangle(n, k):
    # initially, the whole array is all 1s (the edges of the triangle are 1s, remember)
    A = [1] * (n+1)
    
    # repeat for each level of the triangle
    for i in range(1,n):
    
        # calculate the numbers on this level, by
        # a) going backwards (this is crucial: going forwards would overwrite values of the previous level before they are read)
        # b) setting each item to the sum of its left-hand neighbour and this item on the previous level
        for j in range(i,0,-1):
            A[j] += A[j-1]
    return A[k]

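Sounds too complicated? Take a look at the contents of A for n = 10, printed once before the outer loop starts (i= 0) and then after each pass of the outer loop: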

i= 0 [1,  1,  1,   1,   1,   1,   1,   1,  1,  1, 1]
i= 1 [1,  2,  1,   1,   1,   1,   1,   1,  1,  1, 1]
i= 2 [1,  3,  3,   1,   1,   1,   1,   1,  1,  1, 1]
i= 3 [1,  4,  6,   4,   1,   1,   1,   1,  1,  1, 1]
i= 4 [1,  5, 10,  10,   5,   1,   1,   1,  1,  1, 1]
i= 5 [1,  6, 15,  20,  15,   6,   1,   1,  1,  1, 1]
i= 6 [1,  7, 21,  35,  35,  21,   7,   1,  1,  1, 1]
i= 7 [1,  8, 28,  56,  70,  56,  28,   8,  1,  1, 1]
i= 8 [1,  9, 36,  84, 126, 126,  84,  36,  9,  1, 1]
i= 9 [1, 10, 45, 120, 210, 252, 210, 120, 45, 10, 1]

Note that in each row, $A[j]_{\text{new}} = A[j]_{\text{previous}} + A[j-1]_{\text{previous}}$.
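If you want to reproduce that table yourself, here is a small sketch that runs the same update as bc_pascal_triangle but records a copy of A before the first pass and after every pass of the outer loop. `pascal_trace` is a hypothetical helper name:

```python
def pascal_trace(n):
    # same update as bc_pascal_triangle, keeping a snapshot of A
    # before the outer loop and after each of its passes
    A = [1] * (n + 1)
    snapshots = [list(A)]
    for i in range(1, n):
        for j in range(i, 0, -1):
            A[j] += A[j - 1]
        snapshots.append(list(A))
    return snapshots

for i, row in enumerate(pascal_trace(10)):
    print("i=%2d %s" % (i, row))
```

It prints the same rows as above, just without the column alignment.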

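This code obviously works, but it can be very slow: it needs O(n²) additions, compared to the O(n) multiplications of our other algorithms. OK, there is a better way. The idea is to never build the full products at all: first cancel each divisor against a single factor of the dividend that it divides exactly, then multiply the remaining factors together one by one and divide out each leftover divisor as soon as the running product is divisible by it. The algorithm is a lot more complex (and ugly!), but it does the job better.

Putting this into an algorithm results in some ugly code: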

def bc_ugly_and_safe(n, k):
    # list(...) keeps the ranges mutable (Python 3's range objects are not)
    if k > n/2:
        upper_range = list(range(n,k,-1))
        lower_range = list(range(2, (n-k+1)))
    else:
        upper_range = list(range(n,n-k,-1))
        lower_range = list(range(2, k+1))
    
    # go from the top, thus ensuring the biggest possible divisor is used always
    for i in range(len(lower_range)-1,-1,-1):
        divisor = lower_range[i]
        
        # go from the top, thus ensuring that dividend product factors are as small as possible
        # (loop variable j, so the parameter k is not shadowed)
        for j in range(len(upper_range)-1,-1,-1):
            dividend = upper_range[j]
            if dividend % divisor == 0:
                upper_range[j] //= divisor
                lower_range[i] = None
                break

    # drop the divisors that were already cancelled in the first pass
    lower_range = [d for d in lower_range if d is not None]
    u = 1
    if lower_range:
        for item in upper_range:
            u *= item
            
            # we've just added another factor to the dividend, let's see
            # which of the remaining divisors we can divide out already.
            # Walking backwards lets us delete entries without skipping any.
            for i in range(len(lower_range)-1,-1,-1):
                divisor = lower_range[i]
                if u % divisor == 0:
                    u //= divisor
                    del lower_range[i]

        # why this assertion?
        # the binomial coefficient is an integer, so the product of the
        # remaining divisors always divides u. Once the last factor has been
        # multiplied in, every remaining divisor has therefore been divided out.
        assert not lower_range
    else:    
        for item in upper_range:
            u *= item
    
    return u

If you add monitoring code that tracks the largest integer seen, both in this code and in the Pascal's triangle version, you'll note that this code often uses smaller numbers, and at worst the same numbers. So it is indeed the better algorithm. But does it perform?

         using pascal triangle: took 11.55 for result 68673876895888708995L
           using ugly and safe: took  0.00 for result 68673876895888708995L
             using brute force: took  0.08 for result 68673876895888708995L
     using dynamic programming: took  0.04 for result 68673876895888708995L
          using optimized code: took  0.00 for result 68673876895888708995L

That wasn't even close :)
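If you want to try the monitoring idea yourself, here is a minimal sketch. `bc_pascal_triangle_max` is bc_pascal_triangle with a running maximum added. `bc_multiplicative_max` uses a different, well-known technique that is neither the book's reference solution nor the ugly-and-safe code above: the multiplicative formula C(n, i) = C(n, i-1) × (n-i+1) / i. Both names are hypothetical, and both assume 0 <= k <= n:

```python
def bc_pascal_triangle_max(n, k):
    # bc_pascal_triangle plus "monitor code": also returns the largest
    # value ever stored in A
    A = [1] * (n + 1)
    largest = 1
    for i in range(1, n):
        for j in range(i, 0, -1):
            A[j] += A[j - 1]
            largest = max(largest, A[j])
    return A[k], largest


def bc_multiplicative_max(n, k):
    # multiplicative formula, applied for i = 1 .. r with r = min(k, n - k).
    # C(n, i-1) * (n-i+1) equals i * C(n, i), so each integer division is
    # exact and no intermediate value exceeds r * C(n, r).
    r = min(k, n - k)
    result = 1
    largest = 1
    for i in range(1, r + 1):
        product = result * (n - i + 1)
        largest = max(largest, product)
        result = product // i
    return result, largest
```

Pascal's triangle always computes the whole of row n, so its largest value is the middle entry C(n, n/2), no matter how small k is. For n = 10, k = 3 the multiplicative version actually peaks higher (360 versus 252), but for n = 60, k = 3 it peaks at 102660 while Pascal's triangle has to hold C(60, 30).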