"""
$Id: break-xmlenc-2.py $
$ Date: 2020-01-02 09:37 $
$ Version: 1.0.0 $
Copyright (C) 2020-21 David Ireland, DI Management Services Pty Ltd
<https://www.di-mgt.com.au>
SPDX-License-Identifier: MIT

[**] This has the simplifications from the paper:
1. The plaintext does not contain any "Type-A" character
   except for (possibly) the "<" character (so no entity references like &gt;)
2. Each encrypted block contains only incomplete elements
   (i.e. there exists no start tag followed by element content and an end tag).
[***] Additional simplifications in this program
3. Restriction (2) is ignored by our XML "parser".
   It will accept any XML elements, complete or not. See the code for `oracle()`.
"""
# pylint: disable=unused-wildcard-import
from cryptosyspki import *
# Debugging/logging
import logging
logging.basicConfig(level=logging.ERROR)
DEBUG = False


def dprint(*args):
    """Emit a debug trace made of `args`, formatted the way print() would.

    When DEBUG is true the values are printed to stdout. Otherwise they are
    joined with single spaces and sent to the root logger at DEBUG level.

    The values are joined here rather than handed to logging.debug() as
    extra positional arguments, because logging treats those as %-format
    arguments for the first one; with print-style calls that produces a
    formatting error as soon as DEBUG-level logging is switched on.
    """
    if DEBUG:
        print(*args)
    elif logging.getLogger().isEnabledFor(logging.DEBUG):
        # Only build the message when it will actually be emitted
        logging.debug(" ".join(str(arg) for arg in args))


dprint("PKI version =", Gen.version())

# MODULE-LEVEL STATE
n = 16  # AES block size, in bytes
NCHARS = 128  # Size of the ASCII character set
key = Cnv.fromhex("0123456789ABCDEFF0E1D2C3B4A59687")
# Type-A characters: every code from 0x00 to 0x1F apart from 0x09 (TAB),
# 0x0A (LF) and 0x0D (CR), together with 0x26 ('&') and 0x3C ('<').
type_a = [code for code in range(0x20) if code not in (0x09, 0x0a, 0x0d)] + [0x26, 0x3c]
found_typea = []  # Indices of Type-A characters already located in the plaintext block


def show_byteshex(a):
    """Return a list with one '0xNN' (two lowercase hex digits) string per integer in `a`."""
    return [f"0x{value:02x}" for value in a]


def Dec(iv, c):
    """Return the AES-128 CBC decryption of ciphertext `c` under the global key and IV `iv`."""
    return Cipher.decrypt_block(c, key, iv, Cipher.Alg.AES128, Cipher.Mode.CBC)


def oracle(iv, c):
    """
    Simulated server oracle, simplified per assumptions [**] in the module docstring.

    Decrypts `c` with `iv` and returns
      1 ('security fault') if the last byte is not a padding length in
        0x01..0x10, or if any byte before the padding is a Type-A character;
      0 otherwise.
    """
    plain = list(Dec(iv, c))
    pad_len = plain[n - 1]
    # The final byte must be a legal padding length
    if not 0x1 <= pad_len <= 0x10:
        return 1
    # With the padding removed, a single Type-A byte is a (simplified) parsing error
    for byte in plain[:n - pad_len]:
        if byte in type_a:
            return 1
    return 0


def get_valid_padding_masks(iv, c):
    """Collect the last-byte IV values that the oracle accepts.

    The last byte of `iv` is XORed with each value 0x00..0x7F in turn; every
    resulting byte for which oracle() returns 0 is appended to the result,
    in probing order. `iv` itself is not modified.
    """
    last = n - 1
    accepted = []
    for delta in range(0x80):
        trial_iv = bytearray(iv)
        trial_iv[last] ^= delta
        if oracle(trial_iv, c) == 0:
            accepted.append(trial_iv[last])  # Pset \union IV'_n
    return accepted


def get_iv_with_padding_mask_01(pset, IV):
    """Return a copy of IV whose last byte is set to msk0x10 XOR 0x11.

    Input:
    pset  the n padding masks {msk0x01, msk0x02, ..., msk0x10}
    IV    IV to copy; only the last byte of the copy is replaced
    Output: bytearray copy of IV with byte n-1 set to msk0x10 ^ 0x11
    Raises: AssertionError if `pset` does not hold exactly n values, or if
            no single value stands alone on bit 0x10.
    """
    # Raised explicitly (not via `assert`) so the checks survive `python -O`;
    # the exception type is the same as before.
    if len(pset) != n:
        raise AssertionError(f"expected {n} padding masks, got {len(pset)}")
    dprint("Pset  =", ["0x{:02x}".format(x) for x in pset])
    # GetIvWithPaddingMask01
    # padding masks Pset = {msk0x01, msk0x02, ..., msk0x10}
    # msk0x10 differs from others in the 4th bit
    list4thbit = [((x & 0x10) >> 4) for x in pset]
    dprint("4thbit  =", list4thbit)
    # List of indices to items of value 1
    idx1 = [idx for idx, val in enumerate(list4thbit) if val != 0]
    dprint("Indices of '1': ", idx1)
    # List of indices to items of value 0
    idx0 = [idx for idx, val in enumerate(list4thbit) if val == 0]
    dprint("Indices of '0': ", idx0)
    # One of these lists should contain exactly one element - this is msk0x10
    if len(idx1) == 1:
        msk0x10 = pset[idx1[0]]
    elif len(idx0) == 1:
        msk0x10 = pset[idx0[0]]
    else:
        raise AssertionError("no single padding mask differs from the others in bit 0x10")

    dprint("msk0x10 =", hex(msk0x10))
    iv = bytearray(IV)
    # NB *not* XOR'd with original IV: the last byte is overwritten outright
    iv[n - 1] = msk0x10 ^ 0x11
    return iv


def find_iv(IV, c):
    """Input: A ciphertext C = (C(i-1), C(i)), given as IV and c
    Output: iv that is well-formed.

    Side effect: the index of every position found to hold a '<' is appended
    to the global list `found_typea`.
    """
    work_iv = bytearray(IV)
    masks = []
    for _ in range(99):  # bounded number of retries
        masks = get_valid_padding_masks(work_iv, c)
        count = len(masks)
        dprint("|Pset| =", count)
        assert 0 < count <= n
        if count == n:
            break
        # Fewer than n valid masks means there is a '<' at position `count`
        # (1-based), so flip the last bit of that IV byte ...
        blocked = count - 1
        work_iv[blocked] ^= 0x01
        # ... and remember the index, since that byte is now known
        found_typea.append(blocked)

    return get_iv_with_padding_mask_01(masks, work_iv)


def ComputeSetAset(iv, c, j):
    """Input: C = (iv, c), j in {0,...,n-1}
    Output: list Aset of the masks 0xR0 (R = 0..7) for which XORing byte j
    of iv with the mask makes oracle() return 1.
    """
    hits = []
    for mask in range(0x00, 0x80, 0x10):  # 0x00, 0x10, ..., 0x70
        # Work on a fresh copy so the caller's iv is never touched
        probe_iv = bytearray(iv)
        probe_iv[j] ^= mask
        if oracle(probe_iv, c) == 1:
            hits.append(mask)
    return hits


def FindXByte(c, iv, IV0, j):
    """Recover one byte of the raw block decryption by querying the oracle.

    Input:
    c    Single-block ciphertext
    iv   IV such that C = (iv, c) is well-formed
    IV0  original IV
    j    index in range [0, n)
    Output: j-th byte of x = Dec(k, c)
    Raises: AssertionError if the number of masks in Aset is not 1, 2 or 3.
    """
    dprint(f"Calling FindXByte for j = {j}")

    # Have we already detected a Type-A character here?
    # -- in this simplified case, it is always a '<'
    if j in found_typea:
        dprint("Already found: '<'")
        return ord('<') ^ IV0[j]

    dprint("iv =", show_byteshex(iv))
    if j == n - 1:
        # Special case for the last byte: under `iv` it is taken to decrypt to 0x01
        x_j = 0x01 ^ iv[n - 1]
        dprint("Special case for last byte x(n-1) =", hex(x_j))
        return x_j

    x_j = 0  # Stays 0 if none of the probes below triggers the oracle
    aset = ComputeSetAset(iv, c, j)
    len_aset = len(aset)
    dprint(f"|Aset| = {len_aset}")
    # Only three cases possible
    if len_aset == 1:
        dprint("Case 1")
        # Last 4 bits are equal to 0x?9, 0x?A or 0x?D
        msk = aset[0]
        dprint("msk =", hex(msk))
        # There is exactly one msk' in {0x25, 0x26, 0x21} such that
        # m_j XOR msk XOR msk' = 0x3C is a Type-A character
        for msk1 in [0x25, 0x26, 0x21]:
            iv1 = bytearray(iv)
            iv1[j] ^= msk ^ msk1
            if oracle(iv1, c) == 1:
                x_j = 0x3C ^ iv1[j]  # NB iv1[j] already includes both masks

    elif len_aset == 2:
        dprint("Case 2")
        # Last 4 bits are equal to 0x?0, 0x?1, 0x?2, 0x?3, 0x?4, 0x?5, 0x?7, 0x?8, 0x?b, 0x?e, 0x?f
        # Only the first mask in Aset is needed
        msk = aset[0]
        # There are 11 potential masks msk' for each msk in Aset, but symmetrical (0x20 vs 0x30)
        mset = []
        for msk1 in [0x2c, 0x2d, 0x2e, 0x2f, 0x28, 0x29, 0x2b, 0x24, 0x27, 0x22, 0x23,
                     0x3c, 0x3d, 0x3e, 0x3f, 0x38, 0x39, 0x3b, 0x34, 0x37, 0x32, 0x33]:
            iv1 = bytearray(iv)
            iv1[j] ^= msk ^ msk1
            if oracle(iv1, c) == 1:
                dprint("Found a Type-A match for msk'=", hex(msk1))
                mset.append(msk1)
        dprint("mset =", show_byteshex(mset))
        if len(mset) == 1:
            msk1 = mset[0]
        else:
            # We should have two: one gives 0x3C ('<'), the other 0x26 ('&').
            # A further XOR with 0x31 turns 0x26 into 0x17 (Type-A) but
            # 0x3C into 0x0D (CR, not Type-A), which tells them apart.
            msk1 = mset[0]
            iv1 = bytearray(iv)
            iv1[j] ^= msk ^ msk1 ^ 0x31
            if oracle(iv1, c) == 1:
                # Not this one, so the other
                msk1 = mset[1]
        x_j = 0x3C ^ iv[j] ^ msk ^ msk1

    elif len_aset == 3:
        dprint("Case 3")
        # Last 4 bits are equal to 0x?6, 0x?C
        dprint("Aset =", show_byteshex(aset))
        for msk in aset:
            # Out of the 6 combinations we expect exactly one Type-A outcome
            for msk1 in [0x31, 0x2f]:
                iv1 = bytearray(iv)
                iv1[j] ^= msk ^ msk1
                if oracle(iv1, c) == 1:
                    dprint("Found a Type-A match for msk'=", hex(msk1))
                    # If we found a Type-A
                    if msk1 == 0x31:  # "&"
                        x_j = 0x26 ^ iv[j] ^ msk
                    else:  # "<"
                        x_j = 0x3C ^ iv[j] ^ msk

    else:
        # Should not happen. Raised explicitly (not via `assert`) so that
        # `python -O` cannot turn this into a silent return of 0.
        raise AssertionError(f"unexpected |Aset| = {len_aset} for j = {j}")

    return x_j


def break_block(IV, c):
    """Recover the n plaintext bytes of ciphertext block `c`, whose chaining value is `IV`.

    Resets the global `found_typea` list, obtains a well-formed iv from
    find_iv(), then XORs each byte returned by FindXByte() with the matching
    byte of `IV`. Returns the result as a bytearray of length n.
    """
    global found_typea
    dprint("break_block CT:", show_byteshex(c))
    found_typea = []  # start every block with a clean slate
    iv = find_iv(IV, c)
    dprint("At start, IV is =  ", show_byteshex(IV))
    dprint("FindIV returns iv =", show_byteshex(iv))

    # m_j = x_j XOR IV_j for each byte position j = 0 .. n-1
    return bytearray(FindXByte(c, iv, IV, j) ^ IV[j] for j in range(n))


def debug_block(msg_blocks, m1, msg1, i):
    """Compare recovered block `m1` with the expected bytes `msg1`.

    Emits a byte-by-byte comparison through dprint, prints how many
    positions match, and returns n minus that count (the number of errors).
    `msg_blocks` and `i` are accepted but not used.
    """
    pairs = list(zip(m1, msg1))
    dprint("m1=", m1)
    dprint("m1=", show_byteshex(m1))
    dprint("OK=", show_byteshex(msg1))
    dprint("   ", [f" {chr(code)}  " for code in msg1])
    dprint("   ", [" ok " if got == want else " ** " for got, want in pairs])
    matches = sum(1 for got, want in pairs if got == want)
    print(f"{matches} correct out of {n}")
    return n - matches  # Number of errors


def main():
    """Encrypt a fixed sample message, recover it block by block, and report the error count."""
    # SET UP
    #        |--------------|---------------|---------------|---------------|
    #        1234567890123456789012345678901234567890123456789012345678901234
    plain = "Now <Is>|the <Lime for> all good men to come to the aid of their"
    iv = Cnv.fromhex("FEDCBA9876543210FEDCBA9876543210")
    plain_bytes = plain.encode()
    # The key is a module-level global meant for the oracle only (no peeking!)
    print("INPUT...")
    print(f"MSG='{plain}'")
    print("KY=", Cnv.tohex(key))
    print("IV=", Cnv.tohex(iv))
    print("PT=", Cnv.tohex(plain_bytes))

    # The message length must be an exact multiple of the block size
    expected_blocks = [plain[pos:pos + n] for pos in range(0, len(plain), n)]
    ct = Cipher.encrypt_block(plain_bytes, key, iv, Cipher.Alg.AES128, Cipher.Mode.CBC)
    print("CT=", Cnv.tohex(ct))

    # Cut the ciphertext into n-byte blocks
    ct_blocks = [ct[pos:pos + n] for pos in range(0, len(ct), n)]
    nblocks = len(ct_blocks)
    print(f"Found {nblocks} blocks")

    # The first block is chained to the real IV
    first_ct = ct_blocks[0]
    dprint("CT1=", Cnv.tohex(first_ct))
    dprint("BLOCK 1...")
    recovered = [break_block(iv, first_ct)]

    # Every later block uses the preceding ciphertext block as its IV
    for number, (prev_ct, this_ct) in enumerate(zip(ct_blocks, ct_blocks[1:]), start=2):
        dprint(f"\nBLOCK {number}...")
        recovered.append(break_block(prev_ct, this_ct))

    print("FINAL SOLUTION:")
    print("OUT='" + ''.join(block.decode() for block in recovered) + "'")
    totalerrs = 0
    for idx, (guess, truth) in enumerate(zip(recovered, expected_blocks)):
        print("BLOCK:", idx + 1)
        totalerrs += debug_block(expected_blocks, guess, truth.encode(), idx)
    print(f"Found {totalerrs} errors.")


# Run the demonstration only when this file is executed as a script
if __name__ == "__main__":
    main()