0CTF Quals 2019: "zer0lfsr" writeup

This challenge asked us to recover the initial states of three linear-feedback shift registers (each seeded with a 48-bit value), given the output of a nonlinear combine() function applied to the three individual output streams:

def combine(x1,x2,x3):
    return (x1*x2)^(x2*x3)^(x1*x3)
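
For single-bit inputs, this expression is the majority function: the output is 1 exactly when at least two of the three inputs are 1. A quick sanity check (our own illustration, not part of the challenge files; the majority helper is a hypothetical name) enumerates all eight input combinations:

```python
from itertools import product

def combine(x1, x2, x3):
    return (x1*x2) ^ (x2*x3) ^ (x1*x3)

def majority(x1, x2, x3):
    # Hypothetical helper for comparison: 1 iff at least two inputs are 1.
    return int(x1 + x2 + x3 >= 2)

# Truth table of combine() over all single-bit inputs.
table = {bits: combine(*bits) for bits in product((0, 1), repeat=3)}
for bits, out in table.items():
    print(bits, '->', out)

# combine() agrees with the majority vote on every input.
assert all(out == majority(*bits) for bits, out in table.items())
```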

The given challenge source code consists of a fairly typical lfsr class, along with some code that writes the generated stream to a file:

from secret import init1,init2,init3,FLAG
import hashlib
assert(FLAG=="flag{"+hashlib.sha256(init1+init2+init3).hexdigest()+"}")

class lfsr():
    def __init__(self, init, mask, length):
        self.init = init
        self.mask = mask
        self.lengthmask = 2**(length+1)-1

    def next(self):
        nextdata = (self.init << 1) & self.lengthmask 
        i = self.init & self.mask & self.lengthmask 
        output = 0
        while i != 0:
            output ^= (i & 1)
            i = i >> 1
        nextdata ^= output
        self.init = nextdata
        return output

def combine(x1,x2,x3):
    return (x1*x2)^(x2*x3)^(x1*x3)

if __name__=="__main__":
    l1 = lfsr(int.from_bytes(init1,"big"),0b100000000000000000000000010000000000000000000000,48)
    l2 = lfsr(int.from_bytes(init2,"big"),0b100000000000000000000000000000000010000000000000,48)
    l3 = lfsr(int.from_bytes(init3,"big"),0b100000100000000000000000000000000000000000000000,48)

    with open("keystream","wb") as f:
        for i in range(8192):
            b = 0
            for j in range(8):
                b = (b<<1)+combine(l1.next(),l2.next(),l3.next())
            f.write(chr(b).encode())

While one could try to understand the structure of these computations and find a mathematical way to recover the secret initial states, we went for a less sophisticated approach. Short prefixes of the output depend on very few bits of init1 through init3, and longer prefixes depend on progressively more of them, so we suspected that setting up the whole problem as a system of constraints and simply throwing a generic SMT solver at it would probably work.

There is one little caveat that we stumbled over: the keystream file does not contain the raw output bytes, but the UTF-8 encoding of those bytes interpreted as Unicode code points, because the generator writes chr(b).encode(). That is, if the output value is 0xcc, then what the script above writes to the keystream file are the two bytes c3 8c.
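
The effect is easy to reproduce in isolation. The following sketch (our own illustration; the helper name decode_keystream is hypothetical) shows that values below 0x80 are written as a single byte, values from 0x80 upward become two bytes, and decoding the data as UTF-8 restores one code point per original output byte:

```python
def decode_keystream(raw: bytes) -> list:
    # Undo chr(b).encode(): decode as UTF-8, then map each code point
    # back to the byte value it was created from.
    return [ord(c) for c in raw.decode('utf-8')]

print(chr(0x41).encode())  # b'A'
print(chr(0xcc).encode())  # b'\xc3\x8c'

# Round trip over a few sample values on both sides of 0x80.
original = [0x41, 0xcc, 0x7f, 0x80, 0xff]
raw = b''.join(chr(b).encode() for b in original)
print(len(original), 'values ->', len(raw), 'bytes on disk')
assert decode_keystream(raw) == original
```

One consequence is that the keystream file is longer than the 8192 bytes the loop produces whenever any output byte is 0x80 or above.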

After fixing this small slip-up, we ended up with the following script using z3’s awesome Python bindings:

#!/usr/bin/env python3
import z3, hashlib

class lfsr():
    def __init__(self, init, mask, length):
        self.init = init
        self.mask = mask
        self.lengthmask = 2**(length+1)-1

    def next(self):
        nextdata = (self.init << 1) & self.lengthmask
        i = self.init & self.mask & self.lengthmask
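        output = 0
        # XOR all bits of i together (parity of the tapped bits). The loop bound is
        # more than generous: shifting past the bit-vector width only yields zeros.
        for j in range(self.lengthmask.bit_length() + 42):
            output ^= z3.LShR(i,j) & 1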
        self.init = nextdata ^ output
        return output

def combine(x1,x2,x3):
    return (x1*x2)^(x2*x3)^(x1*x3)

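# The file holds UTF-8 text with one code point per output byte. Read it in binary
# and decode explicitly to avoid locale-dependent decoding and newline translation.
keystream = [ord(x) for x in open('keystream','rb').read().decode('utf-8')]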

inits = [z3.BitVec('init{}'.format(i), 48) for i in (1,2,3)]

l1 = lfsr(inits[0],0b100000000000000000000000010000000000000000000000,48)
l2 = lfsr(inits[1],0b100000000000000000000000000000000010000000000000,48)
l3 = lfsr(inits[2],0b100000100000000000000000000000000000000000000000,48)

solver = z3.Solver()
# 23 bytes give 184 one-bit constraints on the 144 unknown state bits.
for b in keystream[:23]:
    for j in reversed(range(8)):
        solver.add(((b >> j) & 1) == combine(l1.next(),l2.next(),l3.next()))

print(solver.check())
m = solver.model()
print(m)

inits = [m[i].as_long() for i in inits]
sol = b''.join(i.to_bytes(6,'big') for i in inits)
flag = 'flag{' + hashlib.sha256(sol).hexdigest() + '}'
print('-->', flag)

It takes less than 10 seconds to recover the flag:

$ ./pwn.py
sat
[init3 = 191532558614761,
 init2 = 181037482648735,
 init1 = 70989122156399]
--> flag{b527e2621131134ec22250cfbca75e8c9f5ae4f40370871fd55910927f66a1b4}
$
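
As a sanity check, the recovered states can be fed back into a plain-Python copy of the challenge generator and compared against the whole keystream rather than just the first 23 bytes. The sketch below is our own addition: the names MASKS, generate and matches_keystream are hypothetical, the lfsr class follows the challenge source (with the parity loop replaced by an equivalent bit count), and the demo at the end runs on synthetic data so that it works without the keystream file:

```python
class lfsr():
    def __init__(self, init, mask, length):
        self.init = init
        self.mask = mask
        self.lengthmask = 2**(length+1)-1

    def next(self):
        nextdata = (self.init << 1) & self.lengthmask
        i = self.init & self.mask & self.lengthmask
        output = bin(i).count('1') & 1  # parity of the tapped bits
        self.init = nextdata ^ output
        return output

def combine(x1, x2, x3):
    return (x1*x2) ^ (x2*x3) ^ (x1*x3)

# Feedback masks copied from the challenge source.
MASKS = (
    0b100000000000000000000000010000000000000000000000,
    0b100000000000000000000000000000000010000000000000,
    0b100000100000000000000000000000000000000000000000,
)

def generate(inits, nbytes):
    # Reproduce the challenge's output bytes (before UTF-8 encoding).
    regs = [lfsr(init, mask, 48) for init, mask in zip(inits, MASKS)]
    out = []
    for _ in range(nbytes):
        b = 0
        for _ in range(8):
            b = (b << 1) + combine(*[r.next() for r in regs])
        out.append(b)
    return out

def matches_keystream(inits, raw):
    # raw is the content of the keystream file, read in binary mode.
    expected = [ord(c) for c in raw.decode('utf-8')]
    return generate(inits, len(expected)) == expected

# Values printed by the solver run above.
inits = (70989122156399, 181037482648735, 191532558614761)

# Demo on synthetic data: encode 64 generated bytes the way the challenge does.
raw = b''.join(chr(b).encode() for b in generate(inits, 64))
print(matches_keystream(inits, raw))
```

To check against the real data, pass open('keystream','rb').read() as raw instead of the synthetic bytes.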