"""
$Id: break-xmlenc-2.py $
$Date: 2020-01-02 09:37 $
$Version: 1.0.0 $
Copyright (C) 2020-21 David Ireland, DI Management Services Pty Ltd <https://di-mgt.com.au>
SPDX-License-Identifier: MIT

Recover AES-128-CBC encrypted "XML" plaintext block by block, using only an
oracle that reports whether decryption caused a 'security fault' (invalid
padding byte, or a Type-A character in the text before the padding).

[**] This has the simplifications from the paper:
1. The plaintext does not contain any "Type-A" character except for
   (possibly) the "<" character (so no entity references like &gt;).
2. Each encrypted block contains only incomplete elements (i.e. there exists
   no start tag followed by element content and an end tag).

[***] Additional simplifications in this program:
3. Restriction (2) is ignored by our XML "parser". It will accept any XML
   elements, complete or not. See the code for `oracle()`.
"""
# pylint: disable=unused-wildcard-import
from cryptosyspki import *

# Debugging/logging
import logging

logging.basicConfig(level=logging.ERROR)

DEBUG = False


def dprint(*args):
    """Debug output: print to stdout if DEBUG is set, else log at DEBUG level.

    Arguments are space-joined exactly like print(). They are deliberately NOT
    handed to logging as %-format arguments: the messages contain no
    placeholders, so logging would report a formatting error for every call
    as soon as DEBUG-level logging was enabled.
    """
    if DEBUG:
        print(*args)
    elif logging.getLogger().isEnabledFor(logging.DEBUG):
        # Only build the string when it will actually be emitted.
        logging.debug(" ".join(str(a) for a in args))


dprint("PKI version =", Gen.version())

# GLOBAL VARIABLES
n = 16  # Block length in bytes of AES
NCHARS = 128  # Number of ASCII characters
key = Cnv.fromhex("0123456789ABCDEFF0E1D2C3B4A59687")

# Set of Type-A characters: all from 0x00 to 0x1F except 0x09 (TAB),
# 0x0A (LF) and 0x0D (CR), plus 0x26 ('&') and 0x3C ('<').
type_a = (list(range(0x00, 0x08 + 1)) + [0x0b, 0x0c, 0x0e, 0x0f]
          + list(range(0x10, 0x1f + 1)) + [0x26, 0x3c])
_TYPE_A_SET = frozenset(type_a)  # O(1) membership test used by oracle()

# Indices of Type-A ('<') characters already located in the current block;
# reset by break_block(), appended to by find_iv(), read by FindXByte().
found_typea = []


def show_byteshex(a):
    """Return the bytes of `a` as a list of '0x..' hex strings (for debug output)."""
    return ["0x{:02x}".format(x) for x in a]


def _xor_byte(iv, j, mask):
    """Return a bytearray copy of `iv` with byte `j` XOR-ed by `mask`.

    Always copies, so the caller's IV is never modified.
    """
    iv1 = bytearray(iv)
    iv1[j] ^= mask
    return iv1


def Dec(iv, c):
    """Decrypt ciphertext block c using global AES-128 key and given IV."""
    return Cipher.decrypt_block(c, key, iv, Cipher.Alg.AES128, Cipher.Mode.CBC)


def oracle(iv, c):
    """Simulated server response for ciphertext C = (iv, c).

    O(C) = 1, if the server returns a 'security fault'
    O(C) = 0 otherwise.

    Simplified, based on assumptions [**] above. A fault is reported if the
    last decrypted byte is not a valid padding length (1..n), or if any byte
    left after stripping that many padding bytes is a Type-A character.
    """
    m = list(Dec(iv, c))
    pad = m[n - 1]
    # Is the padding byte valid?
    if not 0x01 <= pad <= n:
        return 1
    # Are there any Type-A characters after stripping padding?
    # (simplified parsing error)
    if any(x in _TYPE_A_SET for x in m[:n - pad]):
        return 1
    return 0


def get_valid_padding_masks(iv, c):
    """Return Pset: the values of the last IV byte that yield valid padding.

    Tries iv[n-1] ^ j for j = 0 .. NCHARS-1 and keeps each resulting byte
    value for which the oracle reports no fault. This assumes the last
    plaintext byte under `iv` is 7-bit ASCII, so that these masks reach
    every padding value 0x01..0x10.
    """
    pset = []
    for j in range(NCHARS):
        iv1 = _xor_byte(iv, n - 1, j)
        if oracle(iv1, c) == 0:
            pset.append(iv1[n - 1])  # Pset \union IV'_n
    return pset


def get_iv_with_padding_mask_01(pset, IV):
    """GetIvWithPaddingMask01: return a copy of IV whose last byte decrypts to 0x01.

    Each entry of `pset` is x_n ^ p, where x_n is the last byte of Dec(k, c)
    and p runs over the padding values 0x01..0x10. Only p = 0x10 has bit 4
    set, so msk0x10 is the single entry whose bit 4 differs from all others.
    Then msk0x10 ^ 0x11 = x_n ^ 0x01, i.e. the last plaintext byte becomes 0x01.

    Raises:
        ValueError: if `pset` does not contain exactly n entries.
        RuntimeError: if msk0x10 cannot be identified.
    """
    if len(pset) != n:
        raise ValueError(f"Expected |Pset| = {n}, got {len(pset)}")
    dprint("Pset =", show_byteshex(pset))
    # padding masks Pset = {msk0x01, msk0x02, ..., msk0x10}
    # msk0x10 differs from others in the 4th bit
    list4thbit = [(x & 0x10) >> 4 for x in pset]
    dprint("4thbit =", list4thbit)
    idx1 = [idx for idx, bit in enumerate(list4thbit) if bit != 0]
    dprint("Indices of '1': ", idx1)
    idx0 = [idx for idx, bit in enumerate(list4thbit) if bit == 0]
    dprint("Indices of '0': ", idx0)
    # One of these lists should contain exactly one element - this is msk0x10
    if len(idx1) == 1:
        msk0x10 = pset[idx1[0]]
    elif len(idx0) == 1:
        msk0x10 = pset[idx0[0]]
    else:
        raise RuntimeError("Cannot identify msk0x10 in Pset")
    dprint("msk0x10 =", hex(msk0x10))
    iv = bytearray(IV)
    # NB pset holds actual IV byte values, so *no* XOR with the original IV here
    iv[n - 1] = msk0x10 ^ 0x11
    return iv


def find_iv(IV, c):
    """FindIV: find an IV such that C = (iv, c) is well-formed.

    Input: IV (the real preceding block) and a single ciphertext block c.
    Output: bytearray iv for which the oracle accepts (iv, c) and whose last
    plaintext byte is the padding value 0x01.

    Side effect: appends to the global `found_typea` the index of every '<'
    character it neutralises along the way.

    Raises:
        RuntimeError: if no well-formed IV is found.
    """
    iv = bytearray(IV)
    for _ in range(1, 100):  # repeat ...
        pset = get_valid_padding_masks(iv, c)
        pos = len(pset)
        dprint("|Pset| =", pos)
        if not 0 < pos <= n:
            raise RuntimeError(f"Unexpected |Pset| = {pos}")
        if pos == n:  # ... until every padding value is accepted
            break
        # |Pset| < n means the last '<' is at 1-based position pos: only the
        # pos paddings long enough to strip it were accepted. Flip the low bit
        # of that IV byte, turning '<' (0x3C) into '=' (0x3D) ...
        iv[pos - 1] ^= 0x01
        # ... and save its index because we've already decrypted it
        found_typea.append(pos - 1)
    else:
        raise RuntimeError("find_iv: could not make the ciphertext well-formed")
    return get_iv_with_padding_mask_01(pset, iv)


def ComputeSetAset(iv, c, j):
    """Return Aset: the masks 0xR0 (R = 0..7) that make byte j a Type-A character.

    Input: C = (iv, c) well-formed, j in {0, ..., n-1}.
    Output: list of masks msk for which XOR-ing iv[j] with msk causes a fault.
    """
    return [msk for msk in (R << 4 for R in range(8))
            if oracle(_xor_byte(iv, j, msk), c) == 1]


def FindXByte(c, iv, IV0, j):
    """FindXByte: recover the j-th byte of x = Dec(k, c) (raw block decryption).

    Input:
        c    Single-block ciphertext
        iv   IV such that C = (iv, c) is well-formed (from find_iv)
        IV0  original IV
        j    index in range [0, n)
    Output:
        x_j, so that the plaintext byte is x_j ^ IV0[j].

    The case analysis relies on the plaintext byte m_j = x_j ^ iv[j] being a
    non-Type-A 7-bit character; |Aset| then depends only on its low nibble.
    If no candidate mask matches, x_j is returned as 0 (counted as an error
    later by the caller's comparison).

    Raises:
        RuntimeError: if |Aset| is not 1, 2 or 3.
    """
    x_j = 0
    dprint(f"Calling FindXByte for j = {j}")
    # Have we already detected a Type-A character here?
    # -- in this simplified case, it is always a '<'
    if j in found_typea:
        dprint("Already found: '<'")
        return ord('<') ^ IV0[j]
    dprint("iv =", show_byteshex(iv))
    if j == n - 1:
        # Special case for nth byte: find_iv chose iv so it decrypts to 0x01
        x_j = 0x01 ^ iv[n - 1]
        dprint("Special case for last byte x(n-1) =", hex(x_j))
        return x_j

    aset = ComputeSetAset(iv, c, j)
    len_aset = len(aset)
    dprint(f"|Aset| = {len_aset}")

    if len_aset == 1:
        dprint("Case 1")
        # Low nibble is 0x?9, 0x?A or 0x?D: only 0x19, 0x1A or 0x1D is Type-A
        msk = aset[0]
        dprint("msk =", hex(msk))
        # Exactly one msk' in {0x25, 0x26, 0x21} turns m_j ^ msk into 0x3C '<'
        for msk1 in (0x25, 0x26, 0x21):
            iv1 = _xor_byte(iv, j, msk ^ msk1)
            if oracle(iv1, c) == 1:
                x_j = 0x3C ^ iv1[j]  # NB iv1 already includes both masks
                break
    elif len_aset == 2:
        dprint("Case 2")
        # Low nibble is 0x?0-0x?5, 0x?7, 0x?8, 0x?B, 0x?E or 0x?F, so
        # m_j ^ msk is 0x0L or 0x1L. The 0x2? and 0x3? candidates below cover
        # both possibilities symmetrically, so the first mask suffices.
        msk = aset[0]
        mset = []
        for msk1 in (0x2c, 0x2d, 0x2e, 0x2f, 0x28, 0x29, 0x2b, 0x24, 0x27, 0x22, 0x23,
                     0x3c, 0x3d, 0x3e, 0x3f, 0x38, 0x39, 0x3b, 0x34, 0x37, 0x32, 0x33):
            if oracle(_xor_byte(iv, j, msk ^ msk1), c) == 1:
                dprint("Found a Type-A match for msk'=", hex(msk1))
                mset.append(msk1)
        dprint("mset =", show_byteshex(mset))
        if mset:
            msk1 = mset[0]
            # With two matches, one gives '<' (0x3C) and the other '&' (0x26).
            # XOR with 0x31 turns 0x3C into 0x0D (CR, allowed) but 0x26 into
            # 0x17 (Type-A): a fault means mset[0] gave '&', so use the other.
            if len(mset) > 1 and oracle(_xor_byte(iv, j, msk ^ msk1 ^ 0x31), c) == 1:
                msk1 = mset[1]
            x_j = 0x3C ^ iv[j] ^ msk ^ msk1
    elif len_aset == 3:
        dprint("Case 3")
        # Low nibble is 0x?6 or 0x?C
        dprint("Aset =", show_byteshex(aset))
        # Out of the 6 (msk, msk') combinations exactly one is Type-A:
        # 0x26 ^ 0x31 = 0x17 (m_j ^ msk was '&') or 0x3C ^ 0x2F = 0x13 ('<').
        for msk in aset:
            for msk1, char in ((0x31, 0x26), (0x2f, 0x3C)):
                if oracle(_xor_byte(iv, j, msk ^ msk1), c) == 1:
                    dprint("Found a Type-A match for msk'=", hex(msk1))
                    x_j = char ^ iv[j] ^ msk
    else:
        # Should not happen
        raise RuntimeError(f"FindXByte: unexpected |Aset| = {len_aset}")
    return x_j


def break_block(IV, c):
    """Recover the plaintext of ciphertext block c, given its IV (preceding block).

    Returns the n recovered plaintext bytes as a bytearray. Resets the global
    `found_typea` for this block.
    """
    global found_typea
    dprint("break_block CT:", show_byteshex(c))
    found_typea = []
    iv = find_iv(IV, c)
    dprint("At start, IV is = ", show_byteshex(IV))
    dprint("FindIV returns iv =", show_byteshex(iv))
    # for bytes j = 1 to n in c do x_j = FindXByte(C(i), iv, j); m_j = x_j XOR IV_j
    return bytearray(FindXByte(c, iv, IV, j) ^ IV[j] for j in range(n))


def debug_block(msg_blocks, m1, msg1, i):
    """Compare recovered block m1 with the expected plaintext bytes msg1.

    Prints how many bytes match. `msg_blocks` and `i` are not used (kept for
    the existing call signature). Returns the number of mismatched bytes.
    """
    dprint("m1=", m1)
    dprint("m1=", show_byteshex(m1))
    dprint("OK=", show_byteshex(msg1))
    dprint(" ", [" " + chr(x) + " " for x in msg1])
    dprint(" ", [(" ok " if x == y else " ** ") for x, y in zip(m1, msg1)])
    nok = sum(x == y for x, y in zip(m1, msg1))
    print(f"{nok} correct out of {n}")
    return n - nok  # Number of errors


def main():
    """Encrypt a sample message, then recover it block by block via the oracle."""
    # SET UP
    # |--------------|---------------|---------------|---------------|
    # 1234567890123456789012345678901234567890123456789012345678901234
    msg = "Now <Is>|the <Lime for> all good men to come to the aid of their"
    iv = Cnv.fromhex("FEDCBA9876543210FEDCBA9876543210")
    # key is a global variable available to the Oracle (but no peeking!)
    print("INPUT...")
    print(f"MSG='{msg}'")
    print("KY=", Cnv.tohex(key))
    print("IV=", Cnv.tohex(iv))
    print("PT=", Cnv.tohex(msg.encode()))

    # split into blocks - require exact multiple of block size
    if len(msg.encode()) % n != 0:
        raise ValueError(f"Message length {len(msg.encode())} is not a multiple of {n}")
    msg_blocks = [msg[i:i + n] for i in range(0, len(msg), n)]
    ct = Cipher.encrypt_block(msg.encode(), key, iv, Cipher.Alg.AES128, Cipher.Mode.CBC)
    print("CT=", Cnv.tohex(ct))

    # split up ciphertext into blocks
    ct_blocks = [ct[i:i + n] for i in range(0, len(ct), n)]
    nblocks = len(ct_blocks)
    print(f"Found {nblocks} blocks")

    # array to accept decrypted output
    mout = []
    # Break the first block: its "previous block" is the IV
    ct1 = ct_blocks[0]
    dprint("CT1=", Cnv.tohex(ct1))
    dprint("BLOCK 1...")
    mout.append(break_block(iv, ct1))
    # Break subsequent blocks, each using the preceding ciphertext block as IV
    for i in range(1, nblocks):
        dprint(f"\nBLOCK {i + 1}...")
        mout.append(break_block(ct_blocks[i - 1], ct_blocks[i]))

    print("FINAL SOLUTION:")
    # errors="replace": a wrongly recovered non-ASCII byte must not abort the report
    print("OUT='" + ''.join(m.decode(errors="replace") for m in mout) + "'")

    totalerrs = 0
    for i, (m1, block) in enumerate(zip(mout, msg_blocks)):
        print("BLOCK:", i + 1)
        totalerrs += debug_block(msg_blocks, m1, block.encode(), i)
    print(f"Found {totalerrs} errors.")


if __name__ == "__main__":
    main()