"""
$Id: break-xmlenc.py $
$ Date: 2020-01-02 09:37 $
$ Version: 1.0.0 $
Copyright (C) 2020-21 David Ireland, DI Management Services Pty Ltd
<https://www.di-mgt.com.au>
SPDX-License-Identifier: MIT

Implementation of the Toy Example in 'How to break XML encryption'
by T Jager and J Somorovsky. 
In: Proceedings of the 18th ACM Conference on Computer and Communications Security, 
CCS 2011, Chicago, Illinois, USA, October 17-21, 2011
<http://www.nds.rub.de/media/nds/veroeffentlichungen/2011/10/22/HowToBreakXMLenc.pdf>
"""

# pylint: disable=unused-wildcard-import
from cryptosyspki import *
# Runs at import time: show the value reported by ``Gen.version()`` from cryptosyspki.
print("PKI version =", Gen.version())

# GLOBAL VARIABLES
n = 16  # Block length in bytes of AES
NCHARS = 128    # Number of ASCII characters (7-bit codes 0..127)
w = 0x0     # Our only Type-A character
# Set of Type-A characters; ``Oracle`` replies 0 when any decrypted byte is in this list
type_a = [w]
# AES-128 key known by the Oracle and the setup code in main, but not known by ``break_it_all``
key = Cnv.fromhex("0123456789ABCDEFF0E1D2C3B4A59687")


def Dec(iv, c):
    """Return the AES-128-CBC decryption of ``c`` under the global ``key``.

    ``iv`` is the initialization vector handed to ``Cipher.decrypt_block``;
    the result of that call is returned unchanged.
    """
    algorithm = Cipher.Alg.AES128
    mode = Cipher.Mode.CBC
    return Cipher.decrypt_block(c, key, iv, algorithm, mode)


def Oracle(iv, c):
    """Report whether the decryption of ``c`` under ``iv`` is free of Type-A characters.

    O(C) = 1 if every byte of m = Dec(iv, c) is a Type-B character
    (i.e. not listed in the global ``type_a``), otherwise O(C) = 0.
    """
    for plain_byte in Dec(iv, c):
        if plain_byte in type_a:
            # At least one Type-A character: the plaintext is rejected.
            return 0
    # Every byte was Type-B.
    return 1


def xor_bytes(a, b):
    """Return the byte-wise XOR of two equal-length byte sequences as ``bytes``.

    An ``AssertionError`` is raised (when assertions are enabled) if the
    lengths of ``a`` and ``b`` differ.
    """
    assert len(a) == len(b)
    mixed = bytearray(len(a))
    for pos, (left, right) in enumerate(zip(a, b)):
        mixed[pos] = left ^ right
    return bytes(mixed)


def find_iv(iv, ct1):
    """Find an IV' such that the Oracle accepts (IV', ct1) as well-formed.

    The supplied ``iv`` is tried first; if the Oracle replies 1 for it, it is
    returned unchanged. Otherwise up to 1000 random ``n``-byte strings from
    ``Rng.bytestring`` are tried and the first one the Oracle accepts is
    returned.

    Raises:
        RuntimeError: if none of the candidates is accepted by the Oracle.
    """
    # Does O((IV, C_1)) = 1?  In this case we can set IV' := IV.
    if Oracle(iv, ct1) == 1:
        return iv
    # Otherwise try random bit strings for IV'.
    for _ in range(1000):  # Bounded: safer than ``while 1``!
        candidate = Rng.bytestring(n)
        # The *candidate* must be tested here: the original IV is already
        # known to give 0, so querying it again could never succeed.
        if Oracle(candidate, ct1) == 1:
            return candidate
    # Returning an untested IV' would make break_block produce garbage.
    raise RuntimeError("find_iv: no well-formed IV' found in 1000 random attempts")


def break_block(iv, ct1):
    """Recover the plaintext of one ciphertext block using only the Oracle.

    ``iv`` is the value the block was chained with (the real IV for the first
    block, the previous ciphertext block otherwise) and ``ct1`` is the
    ciphertext block itself. Returns a ``bytearray`` of ``n`` plaintext bytes.

    Raises:
        RuntimeError: if, for some byte position, no mask makes the Oracle
            reply 0 (which means the IV' in use was not well-formed).
    """
    # 1. Use the oracle to compute an initialization vector IV'
    #    such that C' = (IV', C(1)) is well-formed.
    iv1 = find_iv(iv, ct1)
    m = bytearray(n)
    for j in range(n):
        # Recover x_j, the j-th byte of the raw block decryption, by XOR-ing
        # a byte mask onto the j-th byte of IV' until the Oracle replies 0.
        # [NB typo in line 5 of Algorithm 1 in paper: change ``= 1`` to ``= 0``]
        # All non-zero 8-bit masks are tried: when IV' is a random string the
        # j-th plaintext byte under IV' may be any value, not just 7-bit
        # ASCII, so stopping at 127 could miss the mask that works.
        for msk in range(1, 256):
            iv2 = bytearray(iv1)
            iv2[j] = iv1[j] ^ msk
            if Oracle(iv2, ct1) == 0:
                break
        else:
            # Do not fall through with the last mask tried: that would
            # silently yield a wrong plaintext byte.
            raise RuntimeError(f"break_block: no mask found for byte {j}")

        # The Oracle's 0 reply is taken to mean byte j now decrypts to w,
        # so x_j := w XOR IV''(j).
        xj = w ^ iv2[j]
        # m_j = IV_j XOR x_j, using the IV the block was really chained with.
        m[j] = iv[j] ^ xj
    return m


def break_it_all(iv, ct):
    """Given the ciphertext and IV, use the Oracle to find the plaintext.

    ``ct`` is cut into ``n``-byte blocks and each block is handed to
    ``break_block``. Returns a list with one recovered block per ciphertext
    block, in order.
    """
    # Split CT into chunks c1, c2, ...
    ct_blocks = [ct[start:start + n] for start in range(0, len(ct), n)]
    print(f"Found {len(ct_blocks)} blocks")

    # The first block is chained with the real IV.
    first = ct_blocks[0]
    print("CT1=", Cnv.tohex(first))
    recovered = [break_block(iv, first)]
    print("PT1=", recovered[0])

    # Every later block CT(i) uses CT(i-1) in the role of the IV.
    for previous, current in zip(ct_blocks, ct_blocks[1:]):
        recovered.append(break_block(previous, current))

    return recovered


def main():
    """Encrypt a fixed sample message, attack it via the Oracle and print the results."""
    # SET UP
    plaintext = "Now is the time for all good men to come to the aid of their par"
    iv = Cnv.fromhex("FEDCBA9876543210FEDCBA9876543210")
    # key is a global variable available to the Oracle (but no peeking!)
    print("KY=", Cnv.tohex(key))
    print("IV=", Cnv.tohex(iv))
    ct = Cipher.encrypt_block(
        plaintext.encode(), key, iv, Cipher.Alg.AES128, Cipher.Mode.CBC
    )
    print("CT=", Cnv.tohex(ct))
    # What we expect
    print("OK= C3153108A8DD340C0BCB1DFE8D25D2320EE0E66BD2BB4A313FB75C5638E9E177211FC26A1FF51CE35741B76A77DB6DE27435A4C79E56F29BB12B595404C222F1")

    # GO AHEAD AND BREAK IT...
    recovered = break_it_all(iv, ct)
    print(recovered)
    print([block.decode() for block in recovered])

    # Show the original message, split into chunks m1, m2, ..., for comparison
    print("Correct:")
    print([plaintext[pos:pos + n] for pos in range(0, len(plaintext), n)])


# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    main()