"""
$Id: break-xmlenc-2.py $
$ Date: 2020-01-02 09:37 $
$ Version: 1.0.0 $
Copyright (C) 2020-21 David Ireland, DI Management Services Pty Ltd
<https://di-mgt.com.au>
SPDX-License-Identifier: MIT
[**] This has the simplifications from the paper:
1. The plaintext does not contain any "Type-A" character
except for (possibly) the "<" character (so no entity references such as &gt;)
2. Each encrypted block contains only incomplete elements
(i.e. there exists no start tag followed by element content and an end tag).
[***] Additional simplifications in this program
3. Restriction (2) is ignored by our XML "parser".
It will accept any XML elements, complete or not. See the code for `oracle()`.
"""
# pylint: disable=unused-wildcard-import
from cryptosyspki import *
# Debugging/logging
import logging
logging.basicConfig(level=logging.ERROR)
DEBUG = False


def dprint(*args):
    """Emit print-style debug output.

    With DEBUG set, the arguments are written to stdout exactly as print(*args) would.
    Otherwise they are joined with single spaces and passed to logging.debug(), so they
    only appear when the root logger is enabled for DEBUG.
    DEBUG is checked at call time, so it can be toggled after import.
    """
    if DEBUG:
        # Plain builtin: `__builtins__` is a dict (not a module) when this file is imported.
        print(*args)
    elif logging.getLogger().isEnabledFor(logging.DEBUG):
        # logging.debug() treats extra positional args as %-format arguments, so passing
        # print-style args straight through would fail to format; build the message here.
        logging.debug(" ".join(str(arg) for arg in args))
dprint("PKI version =", Gen.version())
# GLOBAL VARIABLES
n = 16  # Block length in bytes of AES
NCHARS = 128  # Number of ASCII characters
key = Cnv.fromhex("0123456789ABCDEFF0E1D2C3B4A59687")  # AES-128 key (16 bytes) used by Dec()
# Set of Type-A characters: every byte from 0x00 to 0x1F except 0x09 (TAB),
# 0x0A (LF) and 0x0D (CR), plus 0x26 ('&') and 0x3C ('<').
# Built in ascending order; the list is identical to spelling the ranges out by hand.
type_a = [ch for ch in range(0x20) if ch not in (0x09, 0x0A, 0x0D)] + [0x26, 0x3C]
# Indices (within the current block) of Type-A characters located by find_iv();
# read by FindXByte() and reset by break_block() for every block.
found_typea = []
def show_byteshex(a):
    """Format each byte of `a` as a '0x'-prefixed, two-digit lowercase hex string."""
    return [f"0x{byte:02x}" for byte in a]
def Dec(iv, c):
    """Return the AES-128-CBC decryption of ciphertext block `c`.

    Uses the module-level `key` and the caller-supplied `iv` as the chaining value.
    """
    return Cipher.decrypt_block(c, key, iv, Cipher.Alg.AES128, Cipher.Mode.CBC)
def oracle(iv, c):
    """Simulated server response for the ciphertext (iv, c).

    O(C) = 1, if the server returns a 'security fault'
    O(C) = 0 otherwise.
    A fault is reported when the last decrypted byte is not a valid padding
    length (0x01..0x10), or when the plaintext left after stripping that many
    bytes contains a Type-A character. Simplified, based on assumptions [**].
    """
    plain = list(Dec(iv, c))
    pad_len = plain[n - 1]
    if not 0x01 <= pad_len <= 0x10:
        return 1  # invalid padding byte
    unpadded = plain[:n - pad_len]
    # (simplified) XML parsing error on any Type-A character
    return 1 if any(byte in type_a for byte in unpadded) else 0
def get_valid_padding_masks(iv, c):
    """Return Pset: the last-byte values of iv' for which the oracle reports no fault.

    Each candidate iv' equals `iv` except that its last byte is iv[n-1] XOR j,
    for j = 0x00 .. 0x7F. The accepted last bytes are returned in the order tried.
    """
    original_last = iv[n - 1]
    accepted = []
    for j in range(0x80):
        candidate = bytearray(iv)
        candidate[n - 1] = original_last ^ j
        if oracle(candidate, c) == 0:
            accepted.append(candidate[n - 1])  # Pset \union IV'_n
    return accepted
def get_iv_with_padding_mask_01(pset, IV):
    """GetIvWithPaddingMask01: return a copy of IV whose last byte makes the pad byte 0x01.

    pset: the n last-byte values found by get_valid_padding_masks(), i.e. the
        masks msk0x01, ..., msk0x10 (in unknown order) under which the padding
        byte decrypts to 0x01 .. 0x10 respectively.
    IV: the IV whose first n-1 bytes are kept unchanged.
    Returns a bytearray.
    Raises ValueError if pset does not contain exactly n values, or if the
    mask for 0x10 cannot be singled out.
    """
    if len(pset) != n:
        raise ValueError(f"expected {n} padding masks, got {len(pset)}")
    dprint("Pset =", ["0x{:02x}".format(x) for x in pset])
    # Padding values 0x01..0x0F all have bit 4 clear while 0x10 has it set, so
    # msk0x10 is the only mask whose bit 4 differs from all the others.
    list4thbit = [(x & 0x10) >> 4 for x in pset]
    dprint("4thbit =", list4thbit)
    # List of indices to items of value 1
    idx1 = [idx for idx, bit in enumerate(list4thbit) if bit]
    dprint("Indices of '1': ", idx1)
    # List of indices to items of value 0
    idx0 = [idx for idx, bit in enumerate(list4thbit) if not bit]
    dprint("Indices of '0': ", idx0)
    # Exactly one of these lists should contain a single element - that is msk0x10
    if len(idx1) == 1:
        msk0x10 = pset[idx1[0]]
    elif len(idx0) == 1:
        msk0x10 = pset[idx0[0]]
    else:
        # Explicit raise: an `assert` here would be stripped under `python -O`.
        raise ValueError("cannot identify the padding mask for 0x10")
    dprint("msk0x10 =", hex(msk0x10))
    iv = bytearray(IV)
    # x ^ msk0x10 == 0x10, hence x ^ (msk0x10 ^ 0x11) == 0x01.
    # NB msk0x10 is an absolute byte value, *not* XOR'd with the original IV.
    iv[n - 1] = msk0x10 ^ 0x11
    return iv
def find_iv(IV, c):
    """FindIV: derive an iv such that the ciphertext (iv, c) is well-formed.

    Input: IV, the starting IV (the preceding ciphertext block C(i-1)); c, the
        ciphertext block C(i).
    Output: a bytearray iv for which the decryption of (iv, c) has padding byte
        0x01 and contains no Type-A character.
    Side effect: appends to the module-level list found_typea the index of each
        Type-A character ('<') that had to be neutralized.
    Raises ValueError if the oracle's answers are inconsistent with the
        assumptions [**] (no valid padding at all, or no convergence).
    """
    iv = bytearray(IV)
    # Each pass neutralizes the last remaining Type-A character. These can only sit
    # at indices 0..n-2 (the pad byte is always stripped), so n passes suffice.
    for _ in range(n):
        pset = get_valid_padding_masks(iv, c)
        pos = len(pset)
        dprint("|Pset| =", pos)
        if not 0 < pos <= n:
            raise ValueError(f"unexpected number of valid padding masks: {pos}")
        if pos == n:
            break
        # |Pset| < n: only pad lengths that strip index pos-1 are accepted, so the
        # last Type-A character is at index pos-1. Flip its lowest bit
        # ('<' 0x3C -> '=' 0x3D, which is not Type-A).
        iv[pos - 1] ^= 0x01
        # Save the index: under assumption [**] 1 the plaintext byte there is '<'.
        found_typea.append(pos - 1)
    else:
        raise ValueError(f"no well-formed iv found after {n} passes")
    return get_iv_with_padding_mask_01(pset, iv)
def ComputeSetAset(iv, c, j):
    """Compute Aset for byte j of the well-formed ciphertext C = (iv, c).

    Input: C = (iv, c), j in {0, ..., n-1}.
    Output: the masks 0xR0 (R = 0..7, ascending) which, XORed into iv[j], make
    the oracle report a fault, i.e. turn plaintext byte j into a Type-A character.
    """
    def with_mask(msk):
        # Fresh copy each time so the caller's iv is never modified.
        trial = bytearray(iv)
        trial[j] ^= msk
        return trial

    return [msk for msk in range(0x00, 0x80, 0x10) if oracle(with_mask(msk), c) == 1]
def _find_x_case1(iv, c, j, aset):
    """|Aset| == 1 (low nibble 0x9, 0xA or 0xD): find the unique msk' giving '<'."""
    msk = aset[0]
    dprint("msk =", hex(msk))
    # iv[j] ^ msk turns byte j into 0x19, 0x1A or 0x1D; exactly one msk' in
    # {0x25, 0x26, 0x21} then turns it into 0x3C ('<'), which the oracle flags.
    for msk1 in (0x25, 0x26, 0x21):
        iv1 = bytearray(iv)
        iv1[j] ^= msk ^ msk1
        if oracle(iv1, c) == 1:
            return 0x3C ^ iv1[j]  # x_j ^ iv1[j] == 0x3C; NB no xoring with masks
    return 0  # best effort (as before): no hit means the assumptions were violated


def _find_x_case2(iv, c, j, aset):
    """|Aset| == 2 (low nibbles 0,1,2,3,4,5,7,8,B,E,F): find the msk' giving '<'."""
    msk = aset[0]
    # 11 potential masks msk' for each msk in Aset, symmetrical (0x20 vs 0x30): the
    # 0x2X ones map 0x1L onto 0x3C, the 0x3X ones map 0x0L onto 0x3C.
    mset = []
    for msk1 in (0x2c, 0x2d, 0x2e, 0x2f, 0x28, 0x29, 0x2b, 0x24, 0x27, 0x22, 0x23,
                 0x3c, 0x3d, 0x3e, 0x3f, 0x38, 0x39, 0x3b, 0x34, 0x37, 0x32, 0x33):
        iv1 = bytearray(iv)
        iv1[j] ^= msk ^ msk1
        if oracle(iv1, c) == 1:
            dprint("Found a Type-A match for msk'=", hex(msk1))
            mset.append(msk1)
    dprint("mset =", show_byteshex(mset))
    if not mset:
        raise ValueError(f"no Type-A match found for byte {j}")
    msk1 = mset[0]
    if len(mset) > 1:
        # Two hits: one yields '<' (0x3C), the other '&' (0x26). An extra XOR with 0x31
        # turns '<' into 0x0D (CR, not Type-A) but '&' into 0x17 (Type-A), so a fault
        # here means mset[0] was the '&' hit and the other one is '<'.
        iv1 = bytearray(iv)
        iv1[j] ^= msk ^ msk1 ^ 0x31
        if oracle(iv1, c) == 1:
            msk1 = mset[1]
    return 0x3C ^ iv[j] ^ msk ^ msk1


def _find_x_case3(iv, c, j, aset):
    """|Aset| == 3 (low nibble 0x6 or 0xC): find which msk in Aset yields '&' or '<'."""
    dprint("Aset =", show_byteshex(aset))
    # Out of the 6 combinations exactly one is Type-A: '&' ^ 0x31 == 0x17 and
    # '<' ^ 0x2f == 0x13; every other combination lands on a non-Type-A byte.
    for msk in aset:
        for msk1 in (0x31, 0x2f):
            iv1 = bytearray(iv)
            iv1[j] ^= msk ^ msk1
            if oracle(iv1, c) == 1:
                dprint("Found a Type-A match for msk'=", hex(msk1))
                target = 0x26 if msk1 == 0x31 else 0x3C  # "&" or "<"
                return target ^ iv[j] ^ msk
    return 0x00  # best effort (as before): no hit means the assumptions were violated


def FindXByte(c, iv, IV0, j):
    """FindXByte: recover the j-th byte of x = Dec(k, c) using only the oracle.

    Input:
    c   Single-block ciphertext
    iv  such that C = (iv, c) is well-formed
    IV0 original IV
    j   index in range [0, n)
    Output: j-th byte of x = Dec(k, c)
    Reads the module-level found_typea filled in by find_iv().
    Raises ValueError if |Aset| is not 1, 2 or 3 (assumptions [**] violated).
    """
    dprint(f"Calling FindXByte for j = {j}")
    # Have we already detected a Type-A character here?
    # -- in this simplified case, it is always a '<'
    if j in found_typea:
        dprint("Already found: '<'")
        return ord('<') ^ IV0[j]
    dprint("iv =", show_byteshex(iv))
    if j == n - 1:
        # Special case for the nth byte: iv makes the padding byte decrypt to 0x01
        x_j = 0x01 ^ iv[n - 1]
        dprint("Special case for last byte x(n-1) =", hex(x_j))
        return x_j
    aset = ComputeSetAset(iv, c, j)
    len_aset = len(aset)
    dprint(f"|Aset| = {len_aset}")
    # Only three cases are possible; |Aset| is fixed by the low nibble of byte j
    if len_aset == 1:
        dprint("Case 1")
        return _find_x_case1(iv, c, j, aset)
    if len_aset == 2:
        dprint("Case 2")
        return _find_x_case2(iv, c, j, aset)
    if len_aset == 3:
        dprint("Case 3")
        return _find_x_case3(iv, c, j, aset)
    # Explicit raise: the former `assert (1 == 0)` vanished under `python -O`.
    raise ValueError(f"unexpected |Aset| = {len_aset} for byte {j}")
def break_block(IV, c):
    """Recover one plaintext block P(i) = Dec(k, C(i)) XOR C(i-1) via the oracle.

    IV is the preceding ciphertext block (or the real IV for the first block) and
    c is the ciphertext block to break. Resets the module-level found_typea and
    returns the recovered n-byte block as a bytearray.
    """
    global found_typea
    dprint("break_block CT:", show_byteshex(c))
    found_typea = []  # forget Type-A positions recorded for the previous block
    iv = find_iv(IV, c)
    dprint("At start, IV is = ", show_byteshex(IV))
    dprint("FindIV returns iv =", show_byteshex(iv))
    # For bytes j = 0 .. n-1: m_j = FindXByte(C(i), iv, j) XOR IV_j
    return bytearray(FindXByte(c, iv, IV, j) ^ IV[j] for j in range(n))
def debug_block(msg_blocks, m1, msg1, i):
    """Compare a recovered block m1 with the expected bytes msg1.

    msg_blocks and i are accepted for the caller's convenience but not used.
    Emits debug dumps, prints "<k> correct out of <n>" and returns the number
    of errors, n - k.
    """
    matches = [got == want for got, want in zip(m1, msg1)]
    dprint("m1=", m1)
    dprint("m1=", show_byteshex(m1))
    dprint("OK=", show_byteshex(msg1))
    dprint(" ", [f" {chr(ch)} " for ch in msg1])
    dprint(" ", [" ok " if same else " ** " for same in matches])
    correct = matches.count(True)
    print(f"{correct} correct out of {n}")
    return n - correct  # Number of errors
def main():
    """Encrypt a demo message, then recover it block by block using only the oracle.

    Prints the inputs, the recovered plaintext and a per-block comparison with
    the original message, followed by the total error count.
    """
    # SET UP
    # |--------------|---------------|---------------|---------------|
    # 1234567890123456789012345678901234567890123456789012345678901234
    msg = "Now <Is>|the <Lime for> all good men to come to the aid of their"
    iv = Cnv.fromhex("FEDCBA9876543210FEDCBA9876543210")
    # key is a global variable available to the Oracle (but no peeking!)
    print("INPUT...")
    print(f"MSG='{msg}'")
    print("KY=", Cnv.tohex(key))
    print("IV=", Cnv.tohex(iv))
    print("PT=", Cnv.tohex(msg.encode()))
    # split into blocks - require exact multiple of block size
    msg_blocks = [msg[i:i + n] for i in range(0, len(msg), n)]
    ct = Cipher.encrypt_block(msg.encode(), key, iv, Cipher.Alg.AES128, Cipher.Mode.CBC)
    print("CT=", Cnv.tohex(ct))
    # split up ciphertext into blocks
    ct_blocks = [ct[i:i + n] for i in range(0, len(ct), n)]
    nblocks = len(ct_blocks)
    print(f"Found {nblocks} blocks")
    if ct_blocks:
        dprint("CT1=", Cnv.tohex(ct_blocks[0]))
    # CBC chaining: block i is paired with block i-1; the first block with the IV.
    chain_ivs = [iv] + ct_blocks[:-1]
    mout = []  # recovered plaintext blocks
    for i, (prev_block, block) in enumerate(zip(chain_ivs, ct_blocks)):
        dprint("BLOCK 1..." if i == 0 else f"\nBLOCK {i + 1}...")
        mout.append(break_block(prev_block, block))
    print("FINAL SOLUTION:")
    # errors="replace": a mis-recovered byte must not crash the report below.
    recovered = b"".join(mout).decode("utf-8", errors="replace")
    print("OUT='" + recovered + "'")
    totalerrs = 0
    for i, (m1, expected) in enumerate(zip(mout, msg_blocks)):
        print("BLOCK:", i + 1)
        totalerrs += debug_block(msg_blocks, m1, expected.encode(), i)
    print(f"Found {totalerrs} errors.")
if __name__ == "__main__":
    # Run the demonstration only when executed as a script, not when imported.
    main()