# @file spx_makesigs100.py (2023-03-18T10:13Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""Reproduce all 100 SPHINCS+-SHA2-128f-simple test cases."""

from spx_sha256 import *
from spx_adrs import Adrs
from spx_util import *
import re  # For manipulation of sig string
import json

"""
Reproduce all 100 SPHINCS+-SHA2-128f-simple test cases 0..99 from
NIST-PQ-Submission-SPHINCS-20201001/KAT/sphincs-sha256-128f-simple/PQCsignKAT_64.rsp

CAUTION: Many hard-coded constants!
Lots of debugging statements; they are off by default (DEBUG=False). Set DEBUG=True to turn them on.
NOTE: the value of `optrand` is not given in the KAT file; it can only be found by running the implementation code.
"""

# DEBUGGING
# DPRINT has print()'s calling convention: with DEBUG on it *is* print,
# with DEBUG off it is a stub that accepts any arguments and does nothing.
DEBUG = False
if DEBUG:
    DPRINT = print
else:
    def DPRINT(*_args, **_kwargs):
        """Discard debug output (DEBUG is off)."""
        return None


def make_sig(count, PKseed, PKroot, SKseed, SK_prf, msg, optrand, expected_hash):
    """Recompute one SPHINCS+-SHA2-128f-simple test-case signature and check it.

    Values are handled as hex-encoded strings throughout. The signature is
    assembled as annotated text: one hex value per line, with a ``# ...``
    comment on the first line of each component. The components appear in
    the order R, FORS (sk + auth path per tree), then for each hypertree
    layer the WOTS+ signature and the auth path.

    Args:
        count: test-case number; used only in printed messages.
        PKseed, PKroot: public-key seed and root (hex).
        SKseed, SK_prf: secret-key seed and PRF key (hex).
        msg: message to sign (hex).
        optrand: randomizer input passed to PRF_msg (hex).
        expected_hash: expected SHA256() of the concatenated signature hex
            (compared case-insensitively).

    Raises:
        AssertionError: if the recomputed hypertree root differs from PKroot,
            or the signature hash differs from expected_hash. These are
            explicit ``raise`` statements (not ``assert``), so the checks
            still run under ``python -O``.
    """
    # CONSTANTS (SPHINCS+-SHA2-128f-simple)
    SPX_DGST_BYTES = 34
    SPX_FORS_TREES = 33
    SPX_FORS_HEIGHT = 6
    SPX_WOTS_W = 16
    w = SPX_WOTS_W
    t = (1 << SPX_FORS_HEIGHT)  # 2^6 = 64 leaves per FORS tree
    SPX_WOTS_LEN = 35
    SPX_D = 22
    SPX_HT_HEIGHT = 3
    SPX_HT_LEAVES = (1 << SPX_HT_HEIGHT)  # 2^3 = 8 leaves per subtree
    # Layout of H_msg = mhash || tree || leaf, derived from the parameters
    # above rather than hard-coding byte offsets and masks.
    SPX_FORS_MSG_BYTES = (SPX_FORS_HEIGHT * SPX_FORS_TREES + 7) // 8  # 25
    SPX_TREE_BITS = SPX_HT_HEIGHT * (SPX_D - 1)  # 63
    SPX_TREE_BYTES = (SPX_TREE_BITS + 7) // 8  # 8
    SPX_LEAF_BITS = SPX_HT_HEIGHT  # 3
    SPX_LEAF_BYTES = (SPX_LEAF_BITS + 7) // 8  # 1
    # Internal consistency check of the constants above (not input validation).
    assert SPX_FORS_MSG_BYTES + SPX_TREE_BYTES + SPX_LEAF_BYTES == SPX_DGST_BYTES

    print(f"\nAbout to compute SPHINCS+ signature for TEST CASE {count}...")

    # Start composing the signature as a hex-encoded string with line breaks and comments
    sig = ""

    # Compute the randomizer R of 16 bytes using hex strings
    R = PRF_msg(SK_prf, optrand, msg)
    DPRINT("R =", R)
    sig += R + " # R\n"

    # Compute H_msg from the message, public key and randomizer.
    h = SHA256(R + PKseed + PKroot + msg)
    h_msg = MGF1_SHA256(h, SPX_DGST_BYTES)  # 34 bytes
    DPRINT("H_msg:", h_msg)

    # Split into mhash (25 bytes), tree address (8 bytes) and idx_leaf (1 byte).
    # Slice offsets are in hex digits, i.e. two per byte.
    DPRINT("Split up H_msg...")
    tree_start = 2 * SPX_FORS_MSG_BYTES
    leaf_start = tree_start + 2 * SPX_TREE_BYTES
    mhash = h_msg[:tree_start]
    tree_hex = h_msg[tree_start:leaf_start]
    leaf_hex = h_msg[leaf_start:leaf_start + 2 * SPX_LEAF_BYTES]
    DPRINT(f"mhash='{mhash}'")
    DPRINT(f"tree='{tree_hex}', leaf='{leaf_hex}'")
    # Decode tree address and leaf index into integers (hex read big-endian)
    tree_addr = int(tree_hex, 16) & ((1 << SPX_TREE_BITS) - 1)  # 63 bits
    DPRINT(f"tree_addr = 0x{tree_addr:016x}")
    idx_leaf = int(leaf_hex, 16) & ((1 << SPX_LEAF_BITS) - 1)  # 3 bits
    DPRINT(f"idx_leaf = {idx_leaf}")

    # Interpret 25-byte mhash as 33 * 6-bit unsigned integers.
    # Bits are read least-significant-first within each byte; the j-th bit
    # read for a tree becomes bit j of that tree's index.
    m = bytes.fromhex(mhash)
    DPRINT(f"mhash={m.hex()}")
    indices = []
    offset = 0
    for _ in range(SPX_FORS_TREES):
        index = 0
        for j in range(SPX_FORS_HEIGHT):
            index |= ((m[offset >> 3] >> (offset & 0x7)) & 0x1) << j
            offset += 1
        indices.append(index)

    DPRINT("message_to_indices:\n", indices, sep='')

    # Compute all 33 FORS signature sk values using the indices
    # (we'll output these later)
    fors_sig_sk = []
    adrs = Adrs(Adrs.FORS_TREE, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    for i, index in enumerate(indices):
        adrs.setTreeIndex(i * t + index)
        DPRINT(f"ADRS={adrs.toHex()}")
        sk = PRF(SKseed, adrs.toHex())
        DPRINT(f"fors_sig_sk[{i}]={sk}")
        fors_sig_sk.append(sk)

    # Compute the authpaths and root values for each of the k FORS trees
    roots = []
    for i in range(SPX_FORS_TREES):
        # Leaves of FORS tree i: pk = F(sk) for every one of its t secret keys
        adrs = Adrs(Adrs.FORS_TREE, layer=0)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        leaves = []
        for j in range(t):
            # Note this computes all the sk's including the one at indices[i]
            adrs.setTreeIndex(i * t + j)
            DPRINT(f"ADRS={adrs.toHex()}")
            sk = PRF(SKseed, adrs.toHex())
            DPRINT(f"fors_sk[{i}][{j}]={sk}")
            pk = F(PKseed, adrs.toHex(), sk)
            DPRINT(f"fors_pk[{i}][{j}]={pk}")
            leaves.append(pk)

        DPRINT("leaves=", leaves, sep='\n')
        # Compute the root value for this FORS tree (fresh ADRS object)
        adrs = Adrs(Adrs.FORS_TREE, layer=0)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        DPRINT(f"ADRS={adrs.toHex()}")
        root = hash_root(leaves, adrs, PKseed, i * t)
        DPRINT(f"root[{i}]={root}")
        roots.append(root)
        # ... and the authpath for indices[i]
        idx = indices[i]
        DPRINT(f"i={i} idx={idx}")
        auth = authpath(leaves, adrs, PKseed, idx, i * t)
        DPRINT(f"fors_auth_path[{i}]:")
        for a in auth:
            DPRINT(a)
        # Output the sig_sk and authpath to the signature value
        sig += fors_sig_sk[i] + f" # fors_sig_sk[{i}]\n"
        sig += auth[0] + f" # fors_auth_path[{i}]\n"
        sig += "\n".join(auth[1:]) + "\n"

    DPRINT(sig)

    # Compute the FORS public key given the roots of the k FORS trees.
    DPRINT("roots:")
    for r in roots:
        DPRINT(r)
    adrs = Adrs(Adrs.FORS_ROOTS, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT(f"ADRS={adrs.toHex()}")
    fors_pk = T_len(PKseed, adrs.toHex(), "".join(roots))
    DPRINT(f"fors_pk={fors_pk}")

    def base_w_with_checksum(msghex, show_csum=False):
        """Return the base-w digits of hex string msghex followed by 3 checksum digits."""
        # With w == 16 every hex digit is exactly one base-w digit.
        assert w == 16  # But check w really is 16
        digits = [int(x, 16) for x in msghex]
        # csum = sum of (w - 1 - digit), kept to 12 bits = 3 base-16 digits
        csum = sum(w - 1 - d for d in digits) & 0xfff
        if show_csum:
            DPRINT(f"csum={csum:03x}")
        digits += [(csum >> 8) & 0xF, (csum >> 4) & 0xF, csum & 0xF]
        return digits

    # Input FORS public key to first WOTS signature
    wots_input = fors_pk
    # Loop for each of 22 subtrees in the HT
    for layer in range(SPX_D):
        DPRINT(f"input to HT at layer {layer}={wots_input}")
        DPRINT(f"tree_addr={tree_addr:x} idx_leaf={idx_leaf}")
        m = base_w_with_checksum(wots_input, False)
        DPRINT(m)
        DPRINT([hex(x) for x in m])
        DPRINT(f"len={len(m)}")
        if len(m) != SPX_WOTS_LEN:
            raise ValueError(
                f"layer {layer}: expected {SPX_WOTS_LEN} base-w digits, got {len(m)}")

        # Compute the next WOTS signature.
        adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        DPRINT(f"ADRS base={adrs.toHex()}")

        htsigs = []
        for idx in range(SPX_WOTS_LEN):  # 35
            DPRINT(f"Generate WOTS+ private key for i = {idx}")
            # sk = PRF(SK.seed, ADRS)
            adrs.setChainAddress(idx)
            adrs_c = adrs.toHex()
            DPRINT(f"ADRS={adrs_c}")
            sk = PRF(SKseed, adrs_c)
            DPRINT(f"sk={sk}")

            # Compute F^m_i(sk): apply F m[idx] times along this chain
            mi = m[idx]
            DPRINT(f"m[{idx}]={mi}")
            x = sk
            # Work on a copy so setHashAddress() below does not modify adrs
            adrs_ht = Adrs.fromHex(adrs.toHex())
            for i in range(mi):
                adrs_ht.setHashAddress(i)
                adrs_c = adrs_ht.toHex()
                DPRINT(f"i={i} ADRS={adrs_c}")
                DPRINT(f"in={x}")
                x = F(PKseed, adrs_c, x)
                DPRINT(f"F(PK.seed, ADRS, in)={x}")

            DPRINT(f"ht_sig:{x}")
            htsigs.append(x)

        # Output this ht_sig (560 bytes) to the signature value
        sig += htsigs[0] + f" # ht_sig[{layer}]\n"
        sig += "\n".join(htsigs[1:]) + "\n"

        # Compute all leaves (WOTS+ public keys) of the subtree at this layer
        leaves = []
        for this_leaf in range(SPX_HT_LEAVES):
            DPRINT(f"this_leaf={this_leaf}")
            adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
            adrs.setTreeAddress(tree_addr)
            adrs.setKeyPairAddress(this_leaf)
            DPRINT(adrs.toHex())
            heads = ""  # concatenation of heads of WOTS+ chains
            for chainaddr in range(SPX_WOTS_LEN):
                adrs.setChainAddress(chainaddr)
                adrs_hex = adrs.toHex()
                sk = PRF(SKseed, adrs_hex)
                DPRINT(f"sk[{chainaddr}]={sk}")
                pk = chain(sk, 0, w - 1, PKseed, adrs_hex,
                           showdebug=(DEBUG and (chainaddr < 2 or chainaddr == (SPX_WOTS_LEN-1))))
                DPRINT(f"pk={pk}")
                heads += pk

            DPRINT(f"Input to thash:\n{heads}")
            # Compress the chain heads with T_len under a WOTS_PK address
            wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=layer)
            wots_pk_adrs.setTreeAddress(tree_addr)
            wots_pk_adrs.setKeyPairAddress(this_leaf)
            wots_pk_addr_hex = wots_pk_adrs.toHex()
            DPRINT(f"wots_pk_addr={wots_pk_addr_hex}")
            leaf = T_len(PKseed, wots_pk_addr_hex, heads)
            DPRINT(f"leaf[{this_leaf}]={leaf}")
            leaves.append(leaf)

        DPRINT(leaves)

        # Compute the root node of Merkle tree using H
        # Start with 8 leaf values in array
        adrs = Adrs(Adrs.TREE, layer=layer)
        adrs.setTreeAddress(tree_addr)
        DPRINT(f"ADRS={adrs.toHex()}")
        root = hash_root(leaves, adrs, PKseed)
        DPRINT(f"root={root}")

        # Compute the authentication path from leaf_idx
        DPRINT(f"Computing authpath for leaf index {idx_leaf}...")
        auth = authpath(leaves, adrs, PKseed, idx_leaf)
        DPRINT("authpath:")
        for a in auth:
            DPRINT(a)

        # Output this authpath to the signature value
        sig += auth[0] + f" # ht_auth_path[{layer}]\n"
        sig += "\n".join(auth[1:]) + "\n"

        # Set next wots_input to root; the low SPX_HT_HEIGHT bits of the tree
        # address become the leaf index in the next higher subtree.
        wots_input = root
        idx_leaf = tree_addr & (SPX_HT_LEAVES - 1)  # (2^3 - 1)
        tree_addr >>= SPX_HT_HEIGHT
        # Loop for next higher subtree...

    # At the end the final root value MUST equal PK.root
    print(f"Final root={wots_input}")
    print(f"Expecting ={PKroot}")
    if wots_input.lower() != PKroot.lower():
        raise AssertionError(
            f"TEST CASE {count}: final root {wots_input} != PK.root {PKroot}")

    # Print out and check the signature
    print(f"sig:\n{sig}")
    print("sig lines =", sig.count('\n'), "(expecting 1068)")
    # Strip sig string down to pure hex digits: drop the " # ..." comments,
    # then the line breaks.
    sighex = re.sub(r'\s+#.*?$', '', sig, flags=re.MULTILINE)
    sighex = sighex.replace('\n', '')
    print("sighex len =", len(sighex), "(expecting 17088 x 2 = 34176)")
    # Compute SHA-256 of signature
    hash_sig = SHA256(sighex)
    print(f"SHA256(sighex)={hash_sig}")
    # Check hash matches what we expect (hex case-insensitive, like the root check)
    print(f"Expected      ={expected_hash}")
    if hash_sig.lower() != expected_hash.lower():
        raise AssertionError(
            f"TEST CASE {count}: signature hash {hash_sig} != expected {expected_hash}")


if __name__ == '__main__':
    # Read in json file with input data for 100 test cases.
    # Each entry is a dict with keys 'count', 'PKseed', 'PKroot', 'SKseed',
    # 'SK_prf', 'msg', 'optrand' and 'hash'.
    # JSON text is UTF-8 (RFC 8259), so don't rely on the locale's default encoding.
    with open('spx_inputkat100.json', 'r', encoding='utf-8') as f:
        katall = json.load(f)
    # Do each test case; make_sig performs the checks for each one.
    for k in katall:
        # def make_sig(count, PKseed, PKroot, SKseed, SK_prf, msg, optrand, expected_hash):
        make_sig(k['count'], k['PKseed'], k['PKroot'], k['SKseed'], k['SK_prf'],
                 k['msg'], k['optrand'], k['hash'])

    print("ALL DONE.")