from z3 import *


# ---------------------------------------------------------------------------
# Plain-Python hash functions (operate on str, use unbounded Python ints)
# ---------------------------------------------------------------------------

def xor_hash(s):
    """Return the XOR of the character codes of *s* (0 for an empty string)."""
    h = 0
    for ch in s:
        h = h ^ ord(ch)
    return h


def rot_hash(s):
    """Return a shift/xor hash of *s*: h = (h << 4) ^ (h >> 28) ^ c per char.

    NOTE(review): the constants read like a 32-bit rotate-left by 4, but ``h``
    is an unbounded Python int that is never masked, so this is not a true
    32-bit rotation once ``h`` grows past 32 bits. Behavior kept as-is.
    """
    h = 0
    for ch in s:
        h = (h << 4) ^ (h >> 28) ^ ord(ch)
    return h


def sax_hash(s):
    """Return the shift-add-xor hash of *s*: h ^= (h << 5) + (h >> 2) + c.

    The result is an unbounded Python int. It is only guaranteed to agree
    with the 64-bit ``sax_hash_encode`` while no intermediate value exceeds
    64 bits (which holds for short strings such as the length-8 search below).
    """
    h = 0
    for ch in s:
        h = h ^ ((h << 5) + (h >> 2) + ord(ch))
    return h


# ---------------------------------------------------------------------------
# Encoding helpers
# ---------------------------------------------------------------------------

def Iff(a, b):
    """Return a z3 formula stating that *a* and *b* imply each other."""
    return And(Implies(a, b), Implies(b, a))


def zext(bits, num_bits):
    """Left-pad the list *bits* with ``False`` up to *num_bits* entries.

    The list is modified in place and also returned. A list that already has
    at least *num_bits* entries is left unchanged.
    """
    # A negative repeat count yields an empty list, so no padding happens.
    bits[:0] = [False] * (num_bits - len(bits))
    return bits


def mk_string(pre, l):
    """Return *l* fresh 8-bit z3 bit-vector variables named pre0, pre1, ...

    A string is modelled as this list of 8-bit variables, one per character.
    """
    return [BitVec(pre + str(i), 8) for i in range(l)]


def eq_string(a, b):
    """Return a formula that holds iff the variable lists *a* and *b* are equal.

    Lists of different length give the Python constant ``False``; two empty
    lists give ``True``.
    """
    if len(a) != len(b):
        return False
    if not a:
        return True
    r = a[0] == b[0]
    for x, y in zip(a[1:], b[1:]):
        r = And(r, x == y)
    return r


def decode_string(a, model):
    """Read the concrete value of every variable in *a* from *model*.

    Returns ``(s, ids)``: ``ids`` is the list of character codes and ``s`` is
    the string built from them with ``chr``.
    """
    ids = []
    for var in a:
        # model_completion=True yields a concrete value even for a variable
        # the model left uninterpreted (plain model[var] would return None).
        ids.append(model.eval(var, model_completion=True).as_long())
    return "".join(chr(i) for i in ids), ids


def _escape(s):
    """Escape *s* like Python 2's 'string_escape' codec, on Python 2 and 3.

    Tab, newline, carriage return, backslash and single quote get their
    backslash escapes; other codes outside printable ASCII become \\xNN.
    """
    special = {'\t': '\\t', '\n': '\\n', '\r': '\\r',
               '\\': '\\\\', "'": "\\'"}
    out = []
    for ch in s:
        code = ord(ch)
        if ch in special:
            out.append(special[ch])
        elif 32 <= code < 127:
            out.append(ch)
        else:
            out.append('\\x%02x' % code)
    return ''.join(out)


# ---------------------------------------------------------------------------
# Encoding of hash functions
# ---------------------------------------------------------------------------

def sax_hash_encode(s):
    """Build the z3 term for sax_hash over the 8-bit variables *s*, in 64 bits.

    Mirrors: h = h ^ ((h << 5) + (h >> 2) + ord(c)), with all arithmetic
    taken modulo 2**64.
    """
    h = BitVecVal(0, 64)
    for ch in s:
        shl = h << BitVecVal(5, 64)
        # LShR is a logical (zero-filling) shift, matching >> on the
        # non-negative Python ints used by sax_hash.
        shr = LShR(h, BitVecVal(2, 64))
        # Widen the 8-bit character to 64 bits before adding.
        mixed = shl + shr + ZeroExt(56, ch)
        h = h ^ mixed
    return h


# ---------------------------------------------------------------------------
# Finding collisions
# ---------------------------------------------------------------------------

def find_hash_collision(name, hash_fun, hash_fun_enc, str_len):
    """Search for two distinct strings of length *str_len* with equal hashes.

    Args:
        name: label printed in the report.
        hash_fun: plain-Python hash, applied to the decoded strings so the
            printed hash values can be checked independently of z3.
        hash_fun_enc: builds the z3 hash term from a list of 8-bit variables.
        str_len: number of characters in each candidate string.

    Prints the outcome; returns None.
    """
    # Two strings, i.e. two lists of 8-bit bit-vector variables.
    a = mk_string("a", str_len)
    b = mk_string("b", str_len)

    solver = Solver()
    # The strings must differ ...
    solver.add(Not(eq_string(a, b)))
    # ... but hash to the same value.
    solver.add(hash_fun_enc(a) == hash_fun_enc(b))
    result = solver.check()

    print("Hash collision check for " + name)
    # solver.check() returns a result object, not a boolean: compare it with
    # z3's sat/unsat constants explicitly.
    if result == sat:
        print(" collision found:")
        m = solver.model()
        a_str, a_ids = decode_string(a, m)
        b_str, b_ids = decode_string(b, m)
        # NOTE: the solver may pick any byte value 0-255, so the characters
        # are not necessarily ASCII despite the wording of the message.
        print(" a =" + _escape(a_str) +
              " consisting of ASCII characters " + str(a_ids) +
              ", hash " + str(hash_fun(a_str)))
        print(" b =" + _escape(b_str) +
              " consisting of ASCII characters " + str(b_ids) +
              ", hash " + str(hash_fun(b_str)))
        print("\n")
    elif result == unsat:
        print("No collision exists.")
    else:
        print("Solver returned unknown: " + str(solver.reason_unknown()))


# find a hash collision for shift-add-xor hashing with strings of length 8
find_hash_collision("Shift-Add-Xor hashing", sax_hash, sax_hash_encode, 8)