Sunday 29 January 2012

Solving Causes' Levenshtein Distance challenge in Python, the Sequel


The 3 faithful readers of this blog have probably seen my previous attempt at cracking Causes' Levenshtein distance challenge. It all went well until Adam Derewecki of Causes commented with the following:
...Pretty good solution though, about 15s on our benchmark machine. Record is 11.3s if you're up to the challenge :)
At first I was like "Yeah right, mate, you're not roping -me- in with that one, I have a startup to run." But the predictable engineer's mind just couldn't let it go. How could someone have done about 30% better in Python? What was I missing? So I started hacking at the code again. Turns out (surprise!) I was missing quite a bit. Let's start by putting the original code up for you to see:

import string
w = set(open("00wordlist.txt").read().splitlines())
f, nf = set(), set(["causes"])

#from b, yield all unused words where levdist==1
def nextgen(b):
    for i in range(len(b)): #for each index in b
        for c in string.ascii_lowercase: #for letters [a..z]
            if c != b[i]:
                #substitute b[i] with c
                if b[:i] + c + b[i+1:] in w:
                    yield b[:i] + c + b[i+1:]
                #inject c before b[i]
                if b[:i] + c + b[i:] in w:
                    yield b[:i] + c + b[i:]
        #remove b[i]
        if b[:i] + b[i+1:] in w: yield b[:i] + b[i+1:]
    
    for c in string.ascii_lowercase: #for letters [a..z]
        if b + c in w: yield b + c #append c after b

while len(nf):
    cf = nf
    nf = set([j for i in cf for j in nextgen(i) if j not in f])
    w -= nf
    f |= nf

print len(f)

First, Adam's suggestion was very good by itself. Why write this:

nf = set([j for i in cf
        for j in nextgen(i) 
            if j not in f])

when you can omit the intermediate array and just write this:

nf = set(j for i in cf 
        for j in nextgen(i) 
            if j not in f)

But it gets better. Since I subtract nf from w, which is where nextgen sources all its candidate words, anything already found can never be yielded again, so why even check if j not in f? No reason. So we end up with the much more palatable:

nf = set(j for i in cf for j in nextgen(i))
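
If you're curious about the difference between the two forms, it's easy to poke at with timeit. This is a throwaway microbenchmark on made-up data rather than the challenge workload, but it shows the shape of the comparison:

import timeit

setup = "data = [str(n) for n in range(100000)]"

# builds an intermediate list first, then a set from it
listcomp = "set([x for x in data])"
# feeds the generator straight into set(), no extra list
genexp = "set(x for x in data)"

print(timeit.timeit(listcomp, setup, number=100))
print(timeit.timeit(genexp, setup, number=100))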

After improving that line, I noticed that I had a line above that was doing absolutely nothing whatsoever:

cf = nf

This line simply betrays my uncertainty about how Python's comprehensions work. It turns out the next line can simply be written as follows, with no need to ever declare cf at all.

nf = set(j for i in nf for j in nextgen(i))
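
The reason this is safe is that the right-hand side is evaluated to completion before the name nf is rebound: set() drains the generator, which still iterates over the old set, and only then does the assignment happen. A toy illustration with a made-up set of numbers:

nums = set([1, 2, 3])
# set() consumes the whole generator expression, which reads
# the old value of nums, before the name nums is rebound
nums = set(n * 10 for n in nums)
print(nums)  # set([10, 20, 30]) under Python 2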

Next up, let's look at the little optimisation I had inside nextgen's inner loop:

if c != b[i]:

Here I used a whole line to make sure I wasn't going to do any useless lookups. Even though I was aiming for small code. Even though Python's set membership test is O(1) anyway. When I looked at the code again and doubted my own premature optimisation, the results were damning: the test cost more time than the lookups it saved. Removing that line yields a speed improvement.
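
A throwaway timeit comparison is enough to check that kind of suspicion yourself. The tiny stand-in wordlist below is made up, but the two loops mirror the guarded and unguarded versions of the real code:

import timeit

setup = """
import string
w = set(['cause', 'causes', 'caused', 'pauses'])
b = 'causes'
"""

# substitution lookups guarded by the character comparison
guarded = """
for i in range(len(b)):
    for c in string.ascii_lowercase:
        if c != b[i]:
            b[:i] + c + b[i+1:] in w
"""

# the same lookups with the guard removed
unguarded = """
for i in range(len(b)):
    for c in string.ascii_lowercase:
        b[:i] + c + b[i+1:] in w
"""

print(timeit.timeit(guarded, setup, number=10000))
print(timeit.timeit(unguarded, setup, number=10000))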

All these improvements were small. They saved 1-2 seconds off the roughly 25 seconds the program takes on my laptop. The big improvement came when I tried the technique seen in this stackoverflow answer. Interrupting the running program a few times and looking at where it had stopped pointed straight at the culprit: the constant use of the slicing operation was not doing me any favours. For every letter and every position in a word I did operations like this:

if b[:i] + c + b[i+1:] in w:
    yield b[:i] + c + b[i+1:]

That's 4 slice operations for the substitution case alone, and the injection case right below it costs another 4, for a total of 8 per letter. So I decided to do the slicing only once per position, assign the results to variables, and reuse those for each letter. That sped things up enormously: it brought the runtime from slightly under 23 seconds to well under 17.
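
In sketch form, with throwaway variable names, the rewritten generator looks like this; it assumes the same string import and w set as the original listing, and the full, properly renamed version is at the end of the post:

#from b, yield all unused words where levdist==1
def nextgen(b):
    for i in range(len(b)):
        # slice once per position, reuse for all 26 letters
        head, tail, tail_less = b[:i], b[i:], b[i+1:]
        for c in string.ascii_lowercase:
            if head + c + tail_less in w: # substitute b[i]
                yield head + c + tail_less
            if head + c + tail in w: # inject c before b[i]
                yield head + c + tail
        if head + tail_less in w: # remove b[i]
            yield head + tail_less
    for c in string.ascii_lowercase:
        if b + c in w: # append c after b
            yield b + c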

UPDATE: After some impromptu after-work tinkering with my co-founder Pagan, we realised Python iterates over lists faster than over strings, which means that adding the line

letters = list(string.ascii_lowercase)

to the setup part of the code speeds things up by a cool 4%.
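
We only noticed this by trying it, but the claim itself is easy to sanity-check with another quick, throwaway timeit comparison (again, a microbenchmark rather than the real workload):

import timeit

# iterate over the 26-character string directly
print(timeit.timeit(
    "for c in s: pass",
    "import string; s = string.ascii_lowercase",
    number=1000000))
# iterate over the same characters stored in a list
print(timeit.timeit(
    "for c in l: pass",
    "import string; l = list(string.ascii_lowercase)",
    number=1000000))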

All these improvements add up to shaving about a third off the total running time. Since Adam said that my programme ran for 15 seconds on the benchmark machine, while the best Python they had ran at 11.3, I suspect this may be enough to beat the frontrunner. Now I just have to get Adam to test this one again.

Another change I made was to improve the horrible variable naming I had last time around, and to add a few more comments. I was also very strict about keeping lines under 65 characters in length. So here is the resulting program:

import string
words = set(open("00wordlist.txt").read().splitlines())
frnds, newfrnds = set(), set(["causes"])
letters = list(string.ascii_lowercase)

#from word wd, yield all unused words where levdist==1
def freefrnds(wd):
    for i in range(len(wd)): #for each index in wd
        wd_upto_i, wd_from_i, wd_after_i = wd[:i], wd[i:], wd[i+1:]
        for char in letters: #for letters [a..z]
            #substitute wd[i] with char
            if wd_upto_i + char + wd_after_i in words:
                yield wd_upto_i + char + wd_after_i
            #inject char before wd[i]
            if wd_upto_i + char + wd_from_i in words:
                yield wd_upto_i + char + wd_from_i
        #remove wd[i] from word
        if wd_upto_i + wd_after_i in words:
            yield wd_upto_i + wd_after_i

    for char in letters: #for letters [a..z]
        #append char after word
        if wd + char in words: yield wd + char

while len(newfrnds):
    newfrnds = set(j for i in newfrnds for j in freefrnds(i))
    frnds |= newfrnds #add newfrnds to the frnds set
    words -= newfrnds #remove list of newfrnds from words

print len(frnds)