"""Search for collisions in simple string hash functions using Z3.

Every hash appears twice: once as plain Python arithmetic on a string
(``xor_hash``, ``rot_hash``, ``sax_hash``) and once as a bit-level encoding
(``*_hash_encode``) over "words", i.e. lists of 32 Boolean terms with the
most significant bit at index 0.  Word entries may be Z3 Boolean terms or
plain Python ``True``/``False`` (used for constant padding bits).
"""
from z3 import *

# Width of an encoded word and number of free bits per encoded character.
WORD_BITS = 32
CHAR_BITS = 7
WORD_MASK = (1 << WORD_BITS) - 1


# some hash functions

def xor_hash(s):
    """Return the XOR of the character codes of s (0 for an empty string)."""
    h = 0
    for ch in s:
        h ^= ord(ch)
    return h


def rot_hash(s):
    """Return the rotating hash of s, truncated to 32 bits.

    The mask keeps the value inside one 32-bit word so that the result
    agrees with rot_hash_encode, which rotates a 32-element bit list.
    """
    h = 0
    for ch in s:
        h = ((h << 4) ^ (h >> 28) ^ ord(ch)) & WORD_MASK
    return h


def sax_hash(s):
    """Return the shift-add-xor hash of s, truncated to 32 bits.

    The mask mirrors sax_hash_encode, whose shifts and additions drop every
    bit above position 31.
    """
    h = 0
    for ch in s:
        h = (h ^ ((h << 5) + (h >> 2) + ord(ch))) & WORD_MASK
    return h


# encoding stuff

def Iff(a, b):
    """Return a Z3 term stating that a and b imply each other."""
    return And(Implies(a, b), Implies(b, a))


def zext(bits, num_bits):
    """Left-pad bits with False up to num_bits entries.

    The list is extended in place and also returned; a list that is already
    long enough is left untouched.
    """
    bits[:0] = [False] * (num_bits - len(bits))
    return bits


def mk_char(pre):
    """Return a word whose low CHAR_BITS bits are fresh Z3 Booleans.

    The variables are created with prefixes pre + "0" .. pre + "6"; the one
    created first ends up in the least significant position.  The remaining
    high bits are the constant False.
    """
    bits = [FreshBool(pre + str(i)) for i in range(CHAR_BITS)]
    bits.reverse()
    return zext(bits, WORD_BITS)


def mk_string(pre, l):
    """Return a list of l symbolic characters built by mk_char(pre)."""
    return [mk_char(pre) for _ in range(l)]


def to_int(bits):
    """Interpret the first 32 entries of bits as an unsigned integer.

    Index 0 is the most significant bit.  Entries are tested for truthiness,
    so they must be concrete values (Python bools or evaluated Z3 constants).
    """
    v = 0
    for i in range(WORD_BITS):
        v = (v << 1) | (1 if bits[i] else 0)
    return v


def chrstr(i):
    """Return chr(i) for codes 33..126, otherwise the text "ascii(<i>)"."""
    return chr(i) if 33 <= i <= 126 else "ascii(" + str(i) + ")"


def xor_encode(a, b):
    """Return the bitwise XOR of the words a and b as a new word."""
    return [Xor(a[i], b[i]) for i in range(WORD_BITS)]


def add_encode(a, b):
    """Return the sum of the words a and b as a ripple-carry adder encoding.

    The carry out of the most significant bit is discarded, so the addition
    wraps around at 32 bits.
    """
    lsb = WORD_BITS - 1
    # Least significant bit: a half adder, there is no incoming carry.
    bits = [Xor(a[lsb], b[lsb])]
    carry = And(a[lsb], b[lsb])
    # Remaining bits, from low to high; collected low-first and reversed at
    # the end so that index 0 is the most significant bit again.
    for i in range(lsb - 1, -1, -1):
        bits.append(simplify(Xor(a[i], Xor(b[i], carry))))
        carry = simplify(Or(And(a[i], b[i]),
                            And(a[i], carry),
                            And(b[i], carry)))
    bits.reverse()
    return bits


def eq_encode(a, b):
    """Return a Z3 term stating that the words a and b agree on every bit."""
    r = Iff(a[0], b[0])
    for i in range(1, WORD_BITS):
        r = And(r, Iff(a[i], b[i]))
    return r


def eq_string(a, b):
    """Return a constraint stating that the encoded strings a and b are equal.

    Strings of different length yield the Python constant False, two empty
    strings yield True; otherwise the result is a Z3 term.
    """
    if len(a) != len(b):
        return False
    if len(a) == 0:
        return True
    r = eq_encode(a[0], b[0])
    for i in range(1, len(a)):
        r = And(r, eq_encode(a[i], b[i]))
    return r


def print_bits(c, sol):
    """Print the 32 bits of word c under model sol, each followed by a space.

    Bits are resolved through eval() so that constant padding bits (plain
    Python bools) are handled as well as Z3 variables.
    """
    print("".join(("1" if eval(c[i], sol) else "0") + " "
                  for i in range(WORD_BITS)))


def eval(val, model):
    """Return the value of a single bit under model.

    Plain Python values are returned as bool without consulting the model.
    Z3 terms are evaluated with model completion so that variables the
    solver left unassigned still get a concrete value.
    """
    if not is_expr(val):
        return bool(val)
    return model.evaluate(val, model_completion=True)


def decode_string(a, model):
    """Decode the encoded string a under model.

    Returns a pair (text, codes) where codes is the list of integer character
    codes and text is the string made of chr(code) for each of them.
    """
    ids = [to_int([eval(ch[k], model) for k in range(WORD_BITS)]) for ch in a]
    s = "".join(chr(code) for code in ids)
    return s, ids


def left_rotate(bits, k):
    """Return bits rotated left by k positions."""
    return bits[k:] + bits[:k]


def shift_left(bits, k):
    """Return bits shifted left by k positions, filling with False."""
    return bits[k:] + [False] * k


def shift_right(bits, k):
    """Return the 32-bit word bits shifted right by k, filling with False."""
    return [False] * k + bits[:WORD_BITS - k]


# encodings of hash functions
# an empty string hashes to the all-False (zero) word

def xor_hash_encode(s):
    """Return the word encoding xor_hash over the encoded string s."""
    if not s:
        return zext([], WORD_BITS)
    h = s[0]
    for c in s[1:]:
        h = xor_encode(h, c)
    return h


def rot_hash_encode(s):
    """Return the word encoding rot_hash over the encoded string s."""
    if not s:
        return zext([], WORD_BITS)
    h = s[0]
    for c in s[1:]:
        h = xor_encode(left_rotate(h, 4), c)
    return h


def sax_hash_encode(s):
    """Return the word encoding sax_hash over the encoded string s.

    Mirrors h = h ^ ((h << 5) + (h >> 2) + ord(s[i])) starting from zero.
    """
    h = zext([], WORD_BITS)
    for c in s:
        h_shl = shift_left(h, 5)
        h_shr = shift_right(h, 2)
        total = add_encode(h_shl, add_encode(h_shr, c))
        h = xor_encode(h, total)
    return h


# finding collisions

def _escape(text):
    """Return text with special characters backslash-escaped for printing."""
    try:
        # Python 2 codec, kept so the output there is unchanged.
        return text.encode('string_escape')
    except LookupError:
        # Python 3 has no 'string_escape' codec.
        return text.encode('unicode_escape').decode('ascii')


def find_hash_collision(name, hash_fun, hash_fun_enc, str_len):
    """Ask Z3 for two different strings of length str_len with equal hashes.

    name         -- label used in the printed report
    hash_fun     -- Python implementation, used to print the concrete hashes
    hash_fun_enc -- bit-level encoding of the same hash
    str_len      -- number of characters in each symbolic string

    The outcome is printed; nothing is returned.
    """
    a = mk_string("a", str_len)
    b = mk_string("b", str_len)
    solver = Solver()
    solver.add(Not(eq_string(a, b)))
    solver.add(eq_encode(hash_fun_enc(a), hash_fun_enc(b)))
    # Compare against z3's `sat` constant explicitly: the object returned by
    # check() must not be used as a truth value.
    result = solver.check()
    print("Hash collision check for " + name)
    if result == sat:
        print(" collision found:")
        m = solver.model()
        a_str, a_ids = decode_string(a, m)
        b_str, b_ids = decode_string(b, m)
        print(" a =" + _escape(a_str) +
              " consisting of ASCII characters " + str(a_ids) +
              ", hash " + str(hash_fun(a_str)))
        print(" b =" + _escape(b_str) +
              " consisting of ASCII characters " + str(b_ids) +
              ", hash " + str(hash_fun(b_str)))
        print("\n")
    elif result == unsat:
        print("No collision exists.")
    else:
        # The solver gave up; that is not a proof that no collision exists.
        print("Solver could not decide: " + str(result))


find_hash_collision("Xor hashing", xor_hash, xor_hash_encode, 5)
find_hash_collision("rotation hashing", rot_hash, rot_hash_encode, 5)
find_hash_collision("Shift-Add-Xor hashing", sax_hash, sax_hash_encode, 5)


#MD5
# NOTE(review): bor, band and bnot are not defined in the visible part of this
# file; presumably they are word-level OR / AND / NOT helpers -- confirm that
# they exist before calling F_enc or G_enc.

def F_enc(x, y, z):
    """Return bor(band(x, y), band(bnot(x), z))."""
    return bor(band(x, y), band(bnot(x), z))


def G_enc(x, y, z):
    """Return bor(band(x, y), bor(band(x, z), band(y, z)))."""
    return bor(band(x, y), bor(band(x, z), band(y, z)))


def H(x, y, z):
    """Return the bitwise XOR of the three words x, y and z."""
    return xor_encode(x, xor_encode(y, z))