Calculating binomial coefficients
In math, the binomial coefficient is defined as

$$\binom{n}{k} = \frac{n!}{k!\,(n-k)!}$$
Find an efficient way to compute this.
Ask yourself: Do I understand the problem?
- Input is two numbers: n and k, with n >= k >= 0 (because otherwise (n-k)! is not defined).
- n! is the mathematical notation for the factorial function: n! = 1 * 2 * ... * n, with 0! = 1.
- Output is a single integer: the value of this particular binomial coefficient.
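For example, with n = 5 and k = 2 the definition works out as follows:

```latex
\binom{5}{2} = \frac{5!}{2!\,(5-2)!} = \frac{120}{2 \cdot 6} = 10
```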
Setting up a test harness
As usual, our first step is to prepare a testbed. Our code strategy for the test will be this:
- We will implement a couple of functions for calculating the binomial coefficient. As usual, we are going to store them in a dictionary keyed by their name, so that we can call them, time their performance and print out the results.
- We will verify that all algorithms return the same output for the same input.
Therefore, our test code looks like this:
# -*- coding: latin-1 -*-
def selftest():
    import time
    # map a descriptive name to each implementation
    methods = { 'using brute force' : bc_brute_force,
                'using dynamic programming' : bc_dynamic_programming,
                # ... (the other implementations are added here as they are introduced)
              }
    expected_result = None
    for method_name in methods:
        method = methods[method_name]
        t0 = time.time()
        # XOR all results together, so one number summarizes the whole run
        result = 0
        for n in range(1000,1010):
            for k in range(5,10):
                result ^= method(n, k)
        print "%30s: took %.2f for result %r" % (method_name, time.time() - t0, result, )
        if expected_result is None:
            expected_result = result
        else:
            # every method must produce the same combined result
            assert result == expected_result

if __name__ == "__main__":
    selftest()
Rather than comparing individual results, this code XORs all results together and compares the combined value. But that is just a small sideshow to the main algorithms.
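XORing keeps the harness short, but two wrong results could in principle cancel each other out. If you'd rather check every input individually, here is a minimal sketch that compares each method against a reference value computed with the standard library's math.factorial. The names reference_bc and check_all are hypothetical, and the method dictionary is passed in rather than hard-coded:

```python
import math

def reference_bc(n, k):
    """Binomial coefficient straight from the definition, via the standard library."""
    return math.factorial(n) // (math.factorial(k) * math.factorial(n - k))

def check_all(methods, n_values, k_values):
    """Compare every method against reference_bc for each (n, k) pair.

    Assumes every k in k_values is <= every n in n_values.
    Returns a list of (method_name, n, k, got, expected) tuples, one per
    mismatch; an empty list means all methods agreed on all inputs.
    """
    mismatches = []
    for n in n_values:
        for k in k_values:
            expected = reference_bc(n, k)
            for name, method in sorted(methods.items()):
                got = method(n, k)
                if got != expected:
                    mismatches.append((name, n, k, got, expected))
    return mismatches
```

Used with the harness's parameters, check_all({'using brute force': bc_brute_force}, range(1000, 1010), range(5, 10)) should return an empty list.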
Brute-Force approach
Let's try brute force first. The mathematical formula at the top is pretty straightforward: calculate three factorials, and then evaluate the term:
def fac(n):
    result = 1
    for i in xrange(2, n+1):
        result *= i
    return result

def bc_brute_force(n, k):
    # integer division: the quotient is always exact
    return fac(n) // (fac(k) * fac(n-k))
This raises the question: what is the fastest factorial function here? Let's briefly look at three alternatives:
import operator
import time
from functools import reduce  # a builtin in Python 2, lives in functools in Python 3

def fac1(n):
    result = 1
    for i in range(2, n+1):
        result *= i
    return result

def fac2(n):
    # the initial value 1 makes fac2(0) == 1 instead of raising TypeError
    return reduce(operator.mul, xrange(1, n+1), 1)

def fac3(n):
    result = 1
    i = n
    while i >= 2:
        result *= i
        i -= 1
    return result

# quick testbed: evaluating the best performance of the three factorial functions
for func in (fac1, fac2, fac3, ):
    t0 = time.time()
    func(100000)
    print func, time.time() - t0
Quick: guess which one is the fastest? I for one was surprised:
<function fac1 at 0x01F2D230> 11.4319999218
<function fac2 at 0x0237B9F0> 11.4539999962
<function fac3 at 0x0237B8F0> 9.6970000267
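It turns out that fac3, the plain while loop counting downwards, beats the other two by about 15%. My guess is that the first two functions, while being pythonic and all, pay extra for the machinery around the loop: in Python 2, fac1's range() builds a complete list of n integers up front, and fac2 goes through reduce() and operator.mul for every factor. So fac3 is what we're going to use for a slightly optimized brute force: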
def fac(n):
    result = 1
    i = n
    while i >= 2:
        result *= i
        i -= 1
    return result

def bc_brute_force(n, k):
    return fac(n) // (fac(k) * fac(n-k))
Some notes on this implementation:
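- Counting multiplications, the time complexity of this solution is O(N): it is actually O(N) for each of the three factorials, but the rules of Big-Oh notation reduce O(3N) to O(N). (This treats every multiplication as constant time, which is not quite true for Python's ever-growing big integers, but it is good enough for comparing the approaches here.)
- Apart from the numbers themselves, the space complexity of this solution is O(1). Hard to beat.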
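A single time.time() measurement, as in the factorial testbed above, is easily disturbed by whatever else the machine is doing. Here is a sketch of a somewhat more robust benchmark, assuming you are happy to take the best of a few repetitions via the standard timeit module. The names fac_descending and benchmark are hypothetical; fac_descending repeats the fac3 loop so the block runs on its own, and math.factorial is included only as a built-in baseline:

```python
import math
import timeit

def fac_descending(n):
    # same loop as fac3: multiply from n down to 2
    result = 1
    i = n
    while i >= 2:
        result *= i
        i -= 1
    return result

def benchmark(funcs, n, repeat=3):
    """Return {function name: best-of-`repeat` seconds for one call func(n)}."""
    timings = {}
    for func in funcs:
        # each of the `repeat` runs calls func(n) once (number=1);
        # the minimum is the run least disturbed by other activity
        runs = timeit.repeat(lambda: func(n), repeat=repeat, number=1)
        timings[func.__name__] = min(runs)
    return timings

if __name__ == "__main__":
    for name, seconds in sorted(benchmark([fac_descending, math.factorial], 10000).items()):
        print("%20s: %.4f s" % (name, seconds))
```

Taking the minimum rather than the average is a deliberate choice: noise from other processes can only make a run slower, never faster.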
Using Dynamic Programming
Well, by textbook standards our factorial function already uses dynamic programming of sorts: each step reuses the product computed in the previous one. But we can improve on that by noticing that the original definition asks us to calculate three things:
- n!
- k! with k <= n
- (n-k)! with (n-k) <= n
That means that as we "visit" each number to calculate n!, we might as well pick up k! and (n-k)! along the way, turning three passes into one.
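def bc_dynamic_programming(n, k):
    n_fak = 1
    # 0! == 1, so starting at 1 also covers k == 0 and k == n
    k_fak = 1
    n_minus_k_fak = 1
    n_minus_k = n - k
    for x in range(1, n+1):
        n_fak *= x
        # remember the partial products for x == k and x == n-k on the way up
        if x == k:
            k_fak = n_fak
        if x == n_minus_k:
            n_minus_k_fak = n_fak
    return n_fak // (k_fak * n_minus_k_fak)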
Next stop: algebra
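Take a look at the definition of the factorial again. Since n! = 1 * 2 * ... * k * (k+1) * ... * n, the k! in the denominator cancels against the first k factors of n!:

$$\binom{n}{k} = \frac{(k+1) \cdot (k+2) \cdots n}{1 \cdot 2 \cdots (n-k)}$$

You see that you can actually save a couple of multiplications this way. Now, there is an alternative way to optimize the expression: cancel (n-k)! instead of k!:

$$\binom{n}{k} = \frac{(n-k+1) \cdot (n-k+2) \cdots n}{1 \cdot 2 \cdots k}$$

Think about the question: when does it make sense to use either representation?
It turns out that if k > n/2, then the term 1 * 2 * ... * k is longer than the term 1 * 2 * ... * (n-k), so you want to cancel it, i.e. use the first representation in that case. And if k < n/2, then the second representation wins.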
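For example, with n = 10 and k = 7 (so k > n/2), cancelling k! leaves far fewer factors than cancelling (n-k)!:

```latex
\binom{10}{7} = \frac{8 \cdot 9 \cdot 10}{1 \cdot 2 \cdot 3} = \frac{720}{6} = 120
\qquad\text{versus}\qquad
\binom{10}{7} = \frac{4 \cdot 5 \cdot 6 \cdot 7 \cdot 8 \cdot 9 \cdot 10}{1 \cdot 2 \cdots 7} = \frac{604800}{5040} = 120
```

Both give the same value, but the first needs three factors above and below the line instead of seven.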
Armed with that, let's try to write code for it:
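def bc_optim(n, k):
    # cancel the longer of k! and (n-k)! against the top of n!
    if k > n/2:
        upper_range = range(k+1,n+1)      # (k+1) * ... * n
        lower_range = range(2, (n-k+1))   # 2 * ... * (n-k)
    else:
        upper_range = range(n-k+1,n+1)    # (n-k+1) * ... * n
        lower_range = range(2, k+1)       # 2 * ... * k
    u = 1
    for dividend in upper_range:
        u *= dividend
    l = 1
    for divisor in lower_range:
        l *= divisor
    return u // l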
So, was all that math worth it? Let's look at our testbed and compare the results:
using brute force: took 1.56 for result 4423137553666283105672019847538510671575080839179707002338085191579938072287668281631536191437946606118381203890479389237995287646214963339276827695049847636402807681032197306196418310999689640422676267452377523863414743729520994673266549091320284038814420240634986799913936602081779632146004368166601L
using dynamic programming: took 0.92 for result 4423137553666283105672019847538510671575080839179707002338085191579938072287668281631536191437946606118381203890479389237995287646214963339276827695049847636402807681032197306196418310999689640422676267452377523863414743729520994673266549091320284038814420240634986799913936602081779632146004368166601L
using optimized code: took 0.43 for result 4423137553666283105672019847538510671575080839179707002338085191579938072287668281631536191437946606118381203890479389237995287646214963339276827695049847636402807681032197306196418310999689640422676267452377523863414743729520994673266549091320284038814420240634986799913936602081779632146004368166601L
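So it turns out that the last function is indeed a significant improvement over the first: roughly 3.6 times faster in this run. And that is even though by the rules of Big-Oh nothing has changed in the worst case: when k is close to n/2, the complexity is still O(N), but the constant factors are much better now. (When k is close to 0 or to n, the optimized version only needs about 2 * min(k, n-k) multiplications.)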
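To see where the constant factors come from, here is a sketch that simply counts the multiplications each approach performs, ignoring the fact that big-number multiplications get more expensive as the numbers grow. The count_* functions are hypothetical helpers that mirror the loops of bc_brute_force, bc_dynamic_programming and bc_optim rather than calling them:

```python
def count_brute_force(n, k):
    """Multiplications in bc_brute_force: three factorial loops plus fac(k) * fac(n-k)."""
    def fac_mults(m):
        # fac() multiplies once for each i in m, m-1, ..., 2
        return max(m - 1, 0)
    return fac_mults(n) + fac_mults(k) + fac_mults(n - k) + 1

def count_dynamic_programming(n, k):
    """Multiplications in bc_dynamic_programming: one pass over 1..n plus the final product."""
    return n + 1

def count_optimized(n, k):
    """Multiplications in bc_optim: one per factor in the upper and lower ranges."""
    if k > n // 2:
        upper, lower = range(k + 1, n + 1), range(2, n - k + 1)
    else:
        upper, lower = range(n - k + 1, n + 1), range(2, k + 1)
    return len(upper) + len(lower)

if __name__ == "__main__":
    for n, k in [(1000, 5), (1000, 500), (1000, 995)]:
        print("n=%d k=%d: brute force %d, dynamic programming %d, optimized %d" % (
            n, k, count_brute_force(n, k), count_dynamic_programming(n, k),
            count_optimized(n, k)))
```

For n = 1000, brute force always needs 1998 multiplications and dynamic programming 1001, while the optimized version needs only 9 for k = 5 or k = 995 but 999 for k = 500.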
Caring about overflow: AKA Pascal's Triangle
Now, in the very fine book Elements of Programming Interviews, problem 12.14 asks you to design an efficient algorithm for binomial coefficients (we did that), but one that doesn't overflow.
Well, duh. Python integers don't overflow: they are arbitrary-precision (bignum) integers that simply grow as needed. But there are poor people out there forced to write Java code, or C/C++ code (well, C/C++ has its merits), and there integers are limited to 32 or 64 bits. So how can we help those poor chaps and gals?
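To see how quickly the brute-force formula breaks down with fixed-width integers, here is a small sketch that finds the first n whose factorial no longer fits. It assumes two's-complement signed integers, i.e. a maximum value of 2**(bits - 1) - 1; the name first_factorial_overflow is hypothetical:

```python
def first_factorial_overflow(bits):
    """Smallest n such that n! exceeds the largest signed `bits`-bit integer.

    Assumes two's-complement signed integers with maximum 2**(bits - 1) - 1.
    """
    limit = 2 ** (bits - 1) - 1
    n, fact = 0, 1
    while fact <= limit:
        n += 1
        fact *= n
    return n
```

It returns 13 for 32 bits and 21 for 64 bits. So in a 64-bit language, fac(n) already overflows for n = 21, even though the result C(21, 10) = 352716 fits easily.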
The reference solution in the book works basically like this:
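- Calculate the binomial coefficient using Pascal's triangle. Since we assume that C(n, k) ("n over k") can be represented in 32 bits, so can the coefficients in the triangle that it is built from: every entry C(i, j) that feeds into C(n, k) has j <= k and i - j <= n - k, and such an entry is never larger than C(n, k).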
OK, let's calculate the coefficient using Pascal's triangle. The code is actually very simple, much simpler than our previous suggestions:
def bc_pascal_triangle(n, k):
    A = [1] * (n+1)
    for i in range(1,n):
        for j in range(i,0,-1):
            A[j] += A[j-1]
    return A[k]
Whoa, that was quick. Let's add comments to see what really happens.
def bc_pascal_triangle(n, k):
    # initially, the whole list is all 1s (the edges of the triangle are all 1s, remember)
    A = [1] * (n+1)
    # repeat for each level of the triangle
    for i in range(1,n):
        # calculate the numbers on this level, by
        # a) going backwards (this is crucial: going forwards would overwrite
        #    values from the previous level before they are used)
        # b) setting each item to the sum of its left neighbour and itself,
        #    both still holding their values from the previous level
        for j in range(i,0,-1):
            A[j] += A[j-1]
    return A[k]
Sounds too complicated? Take a look at the output:
i= 0 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
i= 1 [1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1]
i= 2 [1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1]
i= 3 [1, 4, 6, 4, 1, 1, 1, 1, 1, 1, 1]
i= 4 [1, 5, 10, 10, 5, 1, 1, 1, 1, 1, 1]
i= 5 [1, 6, 15, 20, 15, 6, 1, 1, 1, 1, 1]
i= 6 [1, 7, 21, 35, 35, 21, 7, 1, 1, 1, 1]
i= 7 [1, 8, 28, 56, 70, 56, 28, 8, 1, 1, 1]
i= 8 [1, 9, 36, 84, 126, 126, 84, 36, 9, 1, 1]
i= 9 [1, 10, 45, 120, 210, 252, 210, 120, 45, 10, 1]
Note that in each row, A[j] (new) = A[j] (previous row) + A[j-1] (previous row).
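In binomial-coefficient notation, this update is Pascal's rule:

```latex
\binom{i}{j} = \binom{i-1}{j-1} + \binom{i-1}{j}, \qquad \binom{i}{0} = \binom{i}{i} = 1
```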
This code obviously works, but it can be very slow: each of the n levels touches up to n entries, so it is O(N^2), compared to the O(N) we're used to from our other algorithms. OK, there is a better way. The algorithm is a lot more complex (and ugly!), but it is much faster in practice and still keeps the intermediate numbers small.
- Take note that effectively we are dividing one product by another.
- So how about trying to find a suitable factor of the dividend for each divisor as we go along?
- This way, we ensure that the fraction is reduced as far as possible at every step.
- We will have to take care of special situations: sometimes, an individual factor of the dividend will not be divisible by an individual divisor, but the product of two factors will be.
Putting this in an algorithm results in some ugly code:
def bc_ugly_and_safe(n, k):
    # list() keeps the ranges mutable under Python 3 as well
    if k > n/2:
        upper_range = list(range(n,k,-1))
        lower_range = list(range(2, (n-k+1)))
    else:
        upper_range = list(range(n,n-k,-1))
        lower_range = list(range(2, k+1))
    # go from the top, thus ensuring the biggest possible divisor is always used
    for i in range(len(lower_range)-1,-1,-1):
        divisor = lower_range[i]
        # start with the smallest factor (upper_range is in descending order),
        # thus ensuring that dividend product factors are as small as possible
        for j in range(len(upper_range)-1,-1,-1):
            dividend = upper_range[j]
            if dividend % divisor == 0:
                upper_range[j] //= divisor
                lower_range[i] = None
                break
    lower_range = [d for d in lower_range if d is not None]
    u = 1
    if lower_range:
        for item in upper_range:
            u *= item
            # we've just added another factor to the dividend, so divide out
            # every remaining divisor that now divides it (deleting while
            # walking the indices downwards is safe)
            for i in range(len(lower_range)-1,-1,-1):
                divisor = lower_range[i]
                if u % divisor == 0:
                    u //= divisor
                    del lower_range[i]
        # why this assertion?
        # the definition of the binomial coefficient results in integer numbers:
        # u / (product of the remaining divisors) equals C(n, k), so every remaining
        # divisor divides u, and the last pass above removes all of them. q.e.d.
        assert not lower_range
    else:
        for item in upper_range:
            u *= item
    return u
If you add monitoring code that records the highest integer values seen in both this code and the Pascal's-triangle version, you'll note that this code often uses smaller numbers, and at worst the same ones. So it is indeed the better algorithm. But does it perform?
using pascal triangle: took 11.55 for result 68673876895888708995L
using ugly and safe: took 0.00 for result 68673876895888708995L
using brute force: took 0.08 for result 68673876895888708995L
using dynamic programming: took 0.04 for result 68673876895888708995L
using optimized code: took 0.00 for result 68673876895888708995L
That wasn't even close :)
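One more observation on bc_pascal_triangle: A[k] only ever depends on A[0..k], so the inner loop never needs to touch columns beyond k. Those columns hold the largest numbers in each row, and after replacing k by min(k, n-k) every entry that is still computed satisfies C(i, j) <= C(n, j) <= C(n, k). Here is a minimal sketch of that variant, which runs in O(N·K) instead of O(N^2). The name bc_pascal_bounded and the INT32_MAX check are hypothetical: the check merely simulates what fixed-width integers would do, Python itself never needs it:

```python
INT32_MAX = 2 ** 31 - 1  # assumption: simulate a signed 32-bit integer type

def bc_pascal_bounded(n, k, limit=INT32_MAX):
    """C(n, k) via Pascal's rule, filling only columns 0..k of each row.

    Raises OverflowError as soon as an intermediate value exceeds `limit`,
    to mimic fixed-width integers.
    """
    if k < 0 or k > n:
        raise ValueError("need 0 <= k <= n")
    k = min(k, n - k)       # C(n, k) == C(n, n-k); the smaller k bounds all entries by C(n, k)
    A = [1] + [0] * k       # row 0 of the triangle, truncated to columns 0..k
    for i in range(1, n + 1):
        # go backwards so that A[j-1] still holds its value from row i-1
        for j in range(min(i, k), 0, -1):
            A[j] += A[j - 1]
            if A[j] > limit:
                raise OverflowError("C(%d, %d) exceeds the limit" % (i, j))
    return A[k]
```

For instance, bc_pascal_bounded(1000, 3) returns 166167000 after touching only columns 0..3, whereas bc_pascal_bounded(34, 17) raises OverflowError, because C(34, 17) = 2333606220 itself does not fit into a signed 32-bit integer.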