from z3 import *

# Width, in bits, of the hash state used by the symbolic encoding.
BITWIDTH = 64


def sax_hash(s, bitwidth=BITWIDTH):
    """Compute the shift-add-xor hash of ``s`` on concrete values.

    This is the reference version of ``sax_hash_encode``. The state is
    truncated to ``bitwidth`` bits after every step so that the result agrees
    with the fixed-width bit-vector encoding (and can therefore be used to
    double-check a collision reported by the solver). Pass ``bitwidth=None``
    to get unbounded Python-integer arithmetic instead.

    Args:
        s: String to hash; each character contributes ``ord(c)``.
        bitwidth: Width of the hash state in bits, or ``None`` for unbounded.

    Returns:
        The hash value as a non-negative ``int``.
    """
    mask = None if bitwidth is None else (1 << bitwidth) - 1
    h = 0
    for c in s:
        h = h ^ ((h << 5) + (h >> 2) + ord(c))
        if mask is not None:
            # Truncating once per step is equivalent to wrapping every
            # intermediate operation, as the bit-vector encoding does: the
            # low bits of +, ^ and << depend only on the low bits of their
            # inputs, and h is already < 2**bitwidth before the >> 2.
            h &= mask
    return h


# encoding stuff

def Iff(a, b):
    """Return the z3 formula ``a <-> b`` built from two implications."""
    return And(Implies(a, b), Implies(b, a))


def zext(bits, num_bits):
    """Left-pad ``bits`` with ``False`` so it has at least ``num_bits`` entries.

    A new list is returned; the argument is not modified. If ``bits`` already
    has ``num_bits`` or more entries, an unpadded copy is returned.
    """
    return [False] * max(0, num_bits - len(bits)) + list(bits)


def mk_string(pre, l):
    """Create a symbolic string of length ``l``.

    A string is represented as a list of 8-bit z3 bit-vector variables named
    ``pre0``, ``pre1``, ..., ``pre{l-1}``.
    """
    return [BitVec(pre + str(i), 8) for i in range(l)]


def eq_string(a, b):
    """Return a formula stating that symbolic strings ``a`` and ``b`` are equal.

    Strings of different length yield the Python constant ``False``; two empty
    strings yield ``True``. z3 accepts these Python booleans wherever a
    formula is expected. Otherwise the result is the conjunction of
    element-wise equalities.
    """
    if len(a) != len(b):
        return False
    if len(a) == 0:
        # And() needs at least one argument, so handle the empty case here.
        return True
    return And([x == y for x, y in zip(a, b)])


def decode_string(a, model):
    """Read the concrete value of every character of ``a`` from ``model``.

    Returns:
        A tuple ``(text, ids)`` where ``ids`` is the list of character codes
        and ``text`` is the string built from them with ``chr``.
    """
    # model_completion=True assigns a value even to variables the model
    # leaves unconstrained; plain ``model[v]`` would return None for those.
    ids = [model.eval(c, model_completion=True).as_long() for c in a]
    return "".join(chr(i) for i in ids), ids


def sax_hash_encode(s):
    """Build the symbolic shift-add-xor hash of symbolic string ``s``.

    Mirrors ``h = h ^ ((h << 5) + (h >> 2) + ord(c))`` using ``BITWIDTH``-bit
    bit-vectors, so all arithmetic wraps modulo ``2**BITWIDTH``.

    Args:
        s: List of 8-bit z3 bit-vectors, as produced by ``mk_string``.

    Returns:
        A ``BITWIDTH``-bit z3 bit-vector expression for the hash.
    """
    h = BitVecVal(0, BITWIDTH)
    for c in s:
        h_shl = h << BitVecVal(5, BITWIDTH)
        # LShR is a logical (unsigned) shift; z3's ``>>`` on bit-vectors is
        # an arithmetic shift and would not match the unsigned reference.
        h_shr = LShR(h, BitVecVal(2, BITWIDTH))
        # Each character is 8 bits wide, so zero-extend it to the hash width.
        step = h_shl + h_shr + ZeroExt(BITWIDTH - 8, c)
        h = h ^ step
    return h


def find_hash_collision(str_len, ascii_only=False):
    """Search for two distinct strings of length ``str_len`` with equal hashes.

    The result is printed. When ``ascii_only`` is true, every character is
    additionally constrained to the range 0..127; by default any 8-bit value
    is allowed.

    Args:
        str_len: Length of both strings.
        ascii_only: Restrict characters to 7-bit values.

    Returns:
        ``(a, b)``, the two colliding strings, if the solver finds a
        collision; otherwise ``None``.
    """
    # create 2 strings (i.e., lists of bit vector variables)
    a = mk_string("a", str_len)
    b = mk_string("b", str_len)
    solver = Solver()
    # we want the strings to be different
    solver.add(Not(eq_string(a, b)))
    if ascii_only:
        # ULT is unsigned; ``<`` on 8-bit vectors is a signed comparison.
        for c in a + b:
            solver.add(ULT(c, 128))
    # but their hash value should be the same
    a_encoded = sax_hash_encode(a)
    b_encoded = sax_hash_encode(b)
    solver.add(a_encoded == b_encoded)
    result = solver.check()

    # print result
    print("SAX hash collision check")
    if result == sat:
        print(" collision found:")
        m = solver.model()
        a_str, a_ids = decode_string(a, m)
        b_str, b_ids = decode_string(b, m)
        print(" a =" + str(a_str.encode('unicode_escape')) +
              " consisting of ASCII characters " + str(a_ids) +
              ", hash " + str(m.eval(a_encoded)))
        print(" b =" + str(b_str.encode('unicode_escape')) +
              " consisting of ASCII characters " + str(b_ids) +
              ", hash " + str(m.eval(b_encoded)))
        print("\n")
        return a_str, b_str
    if result == unsat:
        print("No collision exists.")
    else:
        # ``unknown`` (e.g. timeout or resource limit) proves nothing.
        print("Solver returned unknown: " + solver.reason_unknown())
    return None


if __name__ == "__main__":
    # find a hash collision for shift-add-xor hashing with strings of length 13
    find_hash_collision(13)