# @file spx_makesig.py (2023-03-16T14:57Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""Do it all to make a SPHINCS+ signature in one go."""

from spx_sha256 import *
from spx_adrs import Adrs
from spx_util import *
import re  # For manipulation of sig string

"""
Reproduce SPHINCS+-SHA2-128f-simple test cases 0 and 1 from
NIST-PQ-Submission-SPHINCS-20201001/KAT/sphincs-sha256-128f-simple/PQCsignKAT_64.rsp

CAUTION: Many hard-coded constants and hard-coded assert statements.
Lots of debugging statements: set DEBUG=True to print them (the default, DEBUG=False, silences them).
NOTE: the value of `optrand` can only be found by running the implementation code.
"""

# DEBUGGING
# Set DEBUG = True to print intermediate values; otherwise DPRINT is a no-op.
DEBUG = False
DPRINT = print if DEBUG else lambda *a, **k: None

##################
# SELECT TEST CASE
##################
# Edit this to select test case 0 or 1...
test_case_to_do = 0

# The hard-coded intermediate-value asserts further down are only valid for
# test case 0, so they are guarded by this flag.
TESTCASE0 = (test_case_to_do == 0)

if test_case_to_do == 0:
    # INPUT FOR TEST CASE 0
    test_title = 'TEST CASE 0'
    PKseed = 'B505D7CFAD1B497499323C8686325E47'
    PKroot = '4FDFA42840C84B1DDD0EA5CE46482020'
    SKseed = '7C9935A0B07694AA0C6D10E4DB6B1ADD'
    SK_prf = '2fd81a25ccb148032dcd739936737f2d'
    msg = \
        'D81C4D8D734FCBFBEADE3D3F8A039FAA2A2C9957E835AD55B22E75BF57BB556AC8'
    optrand = '33b3c07507e4201748494d832b6ee2a6'
    # Expected hash of signature
    expected_hash = 'ea2bef5299332943d7301a883aa6c1caba08975b7924ed581709b5b1c88beaad'

elif test_case_to_do == 1:
    # INPUT FOR TEST CASE 1
    test_title = 'TEST CASE 1'
    PKseed = 'D5A45A4CED06403C5557E87113CB30EA'
    PKroot = '8546AD883DDC43325A606C8B940C5EB1'
    SKseed = '4B622DE1350119C45A9F2E2EF3DC5DF5'
    SK_prf = '0A759D138CDFBD64C81CC7CC2F513345'
    msg = \
        '225D5CE2CEAC61930A07503FB59F7C2F936A3E075481DA3CA299A80F8C5DF9223A073E7B90E02EBF98CA2227EBA38C1AB2568209E46DBA961869C6F83983B17DCD49'
    optrand = '08e25538484cd7f1613248fe6c9f6b4e'
    # Expected hash of signature
    expected_hash = '5224c48f065b48afaa5ebb14bd76aecfc503cecd48a071bba5e6ea366775c2a9'

else:
    print("**INVALID TEST CASE NUMBER!!")
    # `exit()` is a site-module helper meant for the interactive shell and may
    # be absent (e.g. under `python -S`); SystemExit always works in scripts.
    raise SystemExit(1)

# CONSTANTS (hard-coded for SPHINCS+-SHA2-128f-simple)
SPX_DGST_BYTES = 34             # length in bytes of the H_msg output
SPX_FORS_TREES = 33             # number of FORS trees (k)
SPX_FORS_HEIGHT = 6             # height of each FORS tree
t = 2 ** SPX_FORS_HEIGHT        # leaves per FORS tree: 2^6 = 64
w = 16                          # Winternitz parameter: base-16 digits
SPX_WOTS_LEN = 35               # number of WOTS+ chains per signature
# NOTE(review): despite the name, this value is used below as the number of
# hypertree layers (the main signing loop runs once per layer).
SPX_TREE_HEIGHT = 22

print(f"About to compute SPHINCS+ signature for {test_title}...")

# Start composing the signature as a hex-encoded string with line breaks and comments
sig = ""

# Compute the randomizer R of 16 bytes using hex strings
R = PRF_msg(SK_prf, optrand, msg)
DPRINT("R =", R)
if TESTCASE0: assert R == 'b77b5397031e67eb585dba86b10b710b'
sig += R + " # R\n"

# Compute H_msg from the message, public key and randomizer.
h = SHA256(R + PKseed + PKroot + msg)
h_msg = MGF1_SHA256(h, SPX_DGST_BYTES)
DPRINT("H_msg:", h_msg)
if TESTCASE0: assert h_msg == '5b7eb772aecf04c74af07d9d9c1c1f8d3a90dcda00d5bab1dc28daecdc86eb87611e'

# Split H_msg into mhash (25 bytes), tree address (8 bytes) and idx_leaf (1 byte).
# mhash must hold SPX_FORS_TREES indices of SPX_FORS_HEIGHT bits each, rounded
# up to whole bytes. h_msg is hex, so byte counts are doubled for slicing.
mhash_nbytes = (SPX_FORS_TREES * SPX_FORS_HEIGHT + 7) // 8
tree_nbytes = 8
leaf_nbytes = 1
assert mhash_nbytes + tree_nbytes + leaf_nbytes == SPX_DGST_BYTES
tree_start = 2 * mhash_nbytes
leaf_start = tree_start + 2 * tree_nbytes
DPRINT("Split up H_msg...")
mhash = h_msg[:tree_start]
tree_hex = h_msg[tree_start:leaf_start]
leaf_hex = h_msg[leaf_start:leaf_start + 2 * leaf_nbytes]
DPRINT(f"mhash='{mhash}'")
DPRINT(f"tree='{tree_hex}', leaf='{leaf_hex}'")
# Decode tree address and leaf index into integers
tree_addr = int(tree_hex, 16) & 0x7fffffffffffffff  # keep low 63 bits
DPRINT("tree_addr = 0x" + format(tree_addr, '016x'))
idx_leaf = int(leaf_hex, 16) & 0x7  # keep low 3 bits
DPRINT(f"idx_leaf = {idx_leaf}")
if TESTCASE0: assert tree_addr == 0x28daecdc86eb8761
if TESTCASE0: assert idx_leaf == 6

# Interpret the 25-byte mhash as 33 * 6-bit unsigned integers, reading bits
# least-significant first within each byte.
indices = [0] * SPX_FORS_TREES
m = bytes.fromhex(mhash)
DPRINT(f"mhash={m.hex()}")
offset = 0
for i in range(SPX_FORS_TREES):
    for j in range(SPX_FORS_HEIGHT):
        indices[i] ^= ((m[offset >> 3] >> (offset & 0x7)) & 0x1) << j
        offset += 1

DPRINT("message_to_indices:\n", indices, sep='')
if TESTCASE0: assert indices[0] == 27 and indices[32] == 28

# Compute all 33 FORS signature sk values using the indices
# (we'll output these later). For tree i the revealed sk is the one for its
# leaf indices[i], i.e. global FORS tree index i*t + indices[i].
fors_sig_sk = []
# Set up ADRS object
adrs = Adrs(Adrs.FORS_TREE, layer=0)
adrs.setTreeAddress(tree_addr)
adrs.setKeyPairAddress(idx_leaf)

for i in range(SPX_FORS_TREES):
    treeindex = i * t + indices[i]
    adrs.setTreeIndex(treeindex)
    adrs_hex = adrs.toHex()
    DPRINT(f"ADRS={adrs_hex}")
    sk = PRF(SKseed, adrs_hex)
    DPRINT(f"fors_sig_sk[{i}]={sk}")
    fors_sig_sk.append(sk)

if TESTCASE0: assert fors_sig_sk[0] == '8c9f8091d1a1edbb6a8a041343c6e5c0'
if TESTCASE0: assert fors_sig_sk[32] == '446d9fc66808fcc5e0d47c0c381c7f9e'

# Compute the authpaths and root values for each of the k FORS trees
roots = []
for i in range(SPX_FORS_TREES):
    # Compute FORS sk and pk (leaf) values for all t leaves of tree i
    adrs = Adrs(Adrs.FORS_TREE, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    leaves = []
    for j in range(t):
        # Note this recomputes all the sk's, including the one at indices[i]
        treeindex = i * t + j
        adrs.setTreeIndex(treeindex)
        adrs_hex = adrs.toHex()
        DPRINT(f"ADRS={adrs_hex}")
        sk = PRF(SKseed, adrs_hex)
        DPRINT(f"fors_sk[{i}][{j}]={sk}")
        pk = F(PKseed, adrs_hex, sk)
        DPRINT(f"fors_pk[{i}][{j}]={pk}")
        leaves.append(pk)

    DPRINT("leaves=", leaves, sep='\n')
    # Fresh ADRS (no tree index set) for the root and authpath of tree i;
    # i * t is the global index of this tree's first leaf.
    adrs = Adrs(Adrs.FORS_TREE, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT(f"ADRS={adrs.toHex()}")
    root = hash_root(leaves, adrs, PKseed, i * t)
    DPRINT(f"root[{i}]={root}")
    roots.append(root)
    # and the authpath for indices[i]
    idx = indices[i]
    DPRINT(f"i={i} idx={idx}")
    auth = authpath(leaves, adrs, PKseed, idx, i * t)
    DPRINT(f"fors_auth_path[{i}]:")
    for a in auth:
        DPRINT(a)
    # Output the sig_sk and authpath to the signature value
    sig += f"{fors_sig_sk[i]} # fors_sig_sk[{i}]\n"
    sig += f"{auth[0]} # fors_auth_path[{i}]\n"
    sig += "\n".join(auth[1:]) + "\n"

DPRINT(sig)

# Compute the FORS public key given the roots of the k FORS trees.
DPRINT("roots:")
for r in roots:
    DPRINT(r)
adrs = Adrs(Adrs.FORS_ROOTS, layer=0)
adrs.setTreeAddress(tree_addr)
adrs.setKeyPairAddress(idx_leaf)
DPRINT(f"ADRS={adrs.toHex()}")
fors_pk = T_len(PKseed, adrs.toHex(), "".join(roots))
DPRINT(f"fors_pk={fors_pk}")

def wots_chain(msghex, show_csum=False):
    """Convert a hex string to WOTS+ base-w digits and append the checksum digits.

    Despite its name, no hash chains are evaluated here. The returned digits
    are the number of chain steps used when signing `msghex` with WOTS+.

    Args:
        msghex: hex-encoded message. With w == 16 each hex digit is exactly
            one base-w digit, so the string is simply split into digits.
        show_csum: if True, pass the checksum to DPRINT (printed only when
            DEBUG is on).

    Returns:
        list[int]: the message digits, one per hex character, followed by the
        three base-16 digits of the 12-bit checksum, most significant first.
    """
    digits = [int(x, 16) for x in msghex]
    # Checksum = sum of (w - 1 - digit); masked to the 12 bits (3 digits) that
    # are appended below.
    csum = sum(w - 1 - d for d in digits) & 0xfff
    if show_csum:
        DPRINT(f"csum={csum:03x}")
    digits.extend((csum >> shift) & 0xF for shift in (8, 4, 0))
    return digits


# Input FORS public key to first WOTS signature
wots_input = fors_pk
# Each hypertree subtree has height 3, i.e. 8 WOTS+ leaves; idx_leaf selects
# one of them and is always < SUBTREE_LEAVES.
SUBTREE_HEIGHT = 3
SUBTREE_LEAVES = 1 << SUBTREE_HEIGHT
# Loop for each of the 22 (SPX_TREE_HEIGHT) subtree layers in the HT, bottom up
for layer in range(SPX_TREE_HEIGHT):
    DPRINT(f"input to HT at layer {layer}={wots_input}")
    DPRINT(f"tree_addr={tree_addr:x} idx_leaf={idx_leaf}")
    # Chain step counts: base-16 digits of wots_input plus 3 checksum digits
    m = wots_chain(wots_input, False)
    DPRINT(m)
    DPRINT([hex(x) for x in m])
    DPRINT(f"len={len(m)}")

    # Compute the next WOTS signature for keypair idx_leaf.
    # Set up ADRS object
    adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT(f"ADRS base={adrs.toHex()}")

    htsigs = []
    for idx in range(SPX_WOTS_LEN):  # 35
        DPRINT(f"Generate WOTS+ private key for i = {idx}")
        # sk = PRF(SK.seed, ADRS)
        adrs.setChainAddress(idx)
        adrs_c = adrs.toHex()
        DPRINT(f"ADRS={adrs_c}")
        sk = PRF(SKseed, adrs_c)
        DPRINT(f"sk={sk}")

        # Compute F^m_i(sk): apply F m[idx] times, with hash address 0, 1, ...
        mi = m[idx]
        DPRINT(f"m[{idx}]={mi}")
        x = sk
        # Separate Adrs object built from the current address, so setting the
        # hash address does not touch `adrs`
        adrs_ht = Adrs.fromHex(adrs_c)
        for i in range(mi):
            adrs_ht.setHashAddress(i)
            adrs_c = adrs_ht.toHex()
            DPRINT(f"i={i} ADRS={adrs_c}")
            DPRINT(f"in={x}")
            x = F(PKseed, adrs_c, x)
            DPRINT(f"F(PK.seed, ADRS, in)={x}")

        DPRINT(f"ht_sig:{x}")
        htsigs.append(x)

    # Output this ht_sig (35 x 16 = 560 bytes) to the signature value
    sig += f"{htsigs[0]} # ht_sig[{layer}]\n"
    sig += "\n".join(htsigs[1:]) + "\n"

    # Compute all leaves of subtree at this layer
    leaves = []
    for this_leaf in range(SUBTREE_LEAVES):
        DPRINT(f"this_leaf={this_leaf}")
        adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(this_leaf)
        DPRINT(adrs.toHex())
        heads = ""  # concatenation of heads of WOTS+ chains
        for chainaddr in range(SPX_WOTS_LEN):
            adrs.setChainAddress(chainaddr)
            adrs_hex = adrs.toHex()
            sk = PRF(SKseed, adrs_hex)
            DPRINT(f"sk[{chainaddr}]={sk}")
            # Advance the chain from position 0 by w - 1 (presumably its full
            # length) to get its head; debug-trace only the first two and last.
            show = DEBUG and (chainaddr < 2 or chainaddr == SPX_WOTS_LEN - 1)
            pk = chain(sk, 0, w - 1, PKseed, adrs_hex, showdebug=show)
            DPRINT(f"pk={pk}")
            heads += pk

        DPRINT(f"Input to thash:\n{heads}")
        # Compress all chain heads into this leaf using a WOTS_PK address
        wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=layer)
        wots_pk_adrs.setTreeAddress(tree_addr)
        wots_pk_adrs.setKeyPairAddress(this_leaf)
        wots_pk_addr_hex = wots_pk_adrs.toHex()
        DPRINT(f"wots_pk_addr={wots_pk_addr_hex}")
        leaf = T_len(PKseed, wots_pk_addr_hex, heads)
        DPRINT(f"leaf[{this_leaf}]={leaf}")
        leaves.append(leaf)

    DPRINT(leaves)

    # Compute the root node of Merkle tree using H,
    # starting with the SUBTREE_LEAVES leaf values in the list
    adrs = Adrs(Adrs.TREE, layer=layer)
    adrs.setTreeAddress(tree_addr)
    DPRINT(f"ADRS={adrs.toHex()}")
    root = hash_root(leaves, adrs, PKseed)
    DPRINT(f"root={root}")
    if TESTCASE0 and layer == 0: assert root == 'f2ec3b2ae23a50355d057b97df65c8bc'

    # Compute the authentication path from leaf_idx
    DPRINT(f"Computing authpath for leaf index {idx_leaf}...")
    auth = authpath(leaves, adrs, PKseed, idx_leaf)
    DPRINT("authpath:")
    for a in auth:
        DPRINT(a)
    if TESTCASE0 and layer == 0: assert auth[2] == '77a2617d410d8f1acd1fbc29830e1a51'

    # Output this authpath to the signature value
    sig += f"{auth[0]} # ht_auth_path[{layer}]\n"
    sig += "\n".join(auth[1:]) + "\n"

    # The root of this subtree is the input signed at the next layer up; the
    # low SUBTREE_HEIGHT bits of tree_addr become the next leaf index.
    wots_input = root
    idx_leaf = tree_addr & (SUBTREE_LEAVES - 1)
    tree_addr >>= SUBTREE_HEIGHT
    # Loop for next higher subtree...

# At the end the final root value MUST equal PK.root
print(f"Final root={wots_input}")
print(f"Expecting ={PKroot}")
assert wots_input.lower() == PKroot.lower()

# Print out and check the signature
print(f"sig:\n{sig}")
print("sig lines =", sig.count('\n'), "(expecting 1068)")
# Strip sig string down to pure hex digits: drop each trailing " # comment",
# then join the lines together.
sighex = re.sub(r'\s+\#.*?$', '', sig, flags=re.MULTILINE)
sighex = sighex.replace('\n', '')
print("sighex len =", len(sighex), "(expecting 17088 x 2 = 34176)")
# Compute SHA-256 of signature
hash_sig = SHA256(sighex)
print(f"SHA256(sighex)={hash_sig}")
print(f"Expected      ={expected_hash}")
# The PK.root check above does not cover the signature contents (e.g. the
# auth path order), so compare the signature hash too.
assert hash_sig.lower() == expected_hash.lower(), "signature hash mismatch"
print("OK, signature hash matches.")