# @file wots_gen_pk.py (2023-03-16T14:29Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""Compute root node of the top-most HT subtree = PK.root."""
from spx_adrs import Adrs
from spx_sha256 import PRF, T_len
from spx_util import chain, hash_root

# ---- Global parameters shared by the whole script ----------------------
# Seeds are hex strings of 32 digits, i.e. 16 bytes each.
PKseed = 'B505D7CFAD1B497499323C8686325E47'  # public seed
SKseed = '7C9935A0B07694AA0C6D10E4DB6B1ADD'  # secret seed
w = 2 ** 4  # Winternitz parameter (= 16)

# Compute the leaves of the top-most subtree at layer 21 (0x15).
# Each leaf is a compressed WOTS+ public key, built in three steps:
#   1. derive each chain's secret value with PRF(SKseed, ADRS),
#   2. call chain(sk, 0, w - 1, ...) to obtain the head of that chain,
#   3. compress the concatenation of all heads with T_len under a WOTS_PK
#      address for the same layer and key pair.
TOP_LAYER = 0x15  # layer index of the top-most subtree (21)
WOTS_LEN = 35     # WOTS+ chains per key: len1=32 + len2=3 for n=16 bytes, w=16
NUM_LEAVES = 8    # leaves (WOTS+ key pairs) in the subtree

leaves = []
for leaf_idx in range(NUM_LEAVES):
    print(f"leaf_idx={leaf_idx}")
    adrs = Adrs(Adrs.WOTS_HASH, layer=TOP_LAYER)
    adrs.setKeyPairAddress(leaf_idx)
    print(adrs.toHex())
    chain_heads = []  # head of each WOTS+ chain, in chain-address order
    for chainaddr in range(WOTS_LEN):
        adrs.setChainAddress(chainaddr)
        adrs_hex = adrs.toHex()
        sk = PRF(SKseed, adrs_hex)
        print(f"sk[{chainaddr}]={sk}")
        # Emit chain debug output only for the first two chains and the last.
        pk = chain(sk, 0, w - 1, PKseed, adrs_hex,
                   showdebug=(chainaddr < 2 or chainaddr == WOTS_LEN - 1))
        print(f"pk={pk}")
        # Known-answer values for leaf 0, printed for visual comparison.
        if leaf_idx == 0 and chainaddr == 0:
            print("OK=f3a6275658f3be797af7022736613710")
        if leaf_idx == 0 and chainaddr == WOTS_LEN - 1:
            print("OK=5962475912c4f1408d895c2de893f375")
        chain_heads.append(pk)
    heads = "".join(chain_heads)  # concatenation of heads of WOTS+ chains

    print(f"Input to thash:\n{heads}")
    # Compress the chain heads into the leaf value using a WOTS_PK address.
    wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=TOP_LAYER)
    wots_pk_adrs.setKeyPairAddress(leaf_idx)
    wots_pk_addr_hex = wots_pk_adrs.toHex()
    print(f"wots_pk_addr={wots_pk_addr_hex}")
    leaf = T_len(PKseed, wots_pk_addr_hex, heads)
    print(f"leaf[{leaf_idx}]={leaf}")
    leaves.append(leaf)

print(leaves)
# Known-answer checks. Raise explicitly rather than using `assert`, so the
# checks still run when Python is started with -O.
EXPECTED_LEAVES = {
    0: "505df0061b7e0041c8501bc5030ad439",
    7: "84035916aba8e0b92f73364d4bb50a18",
}
for idx, expected in EXPECTED_LEAVES.items():
    if leaves[idx] != expected:
        raise AssertionError(
            f"leaves[{idx}]={leaves[idx]}, expected {expected}")


# Merkle root of the top-most subtree (layer 21 = 0x15, the last of 22
# layers), computed with the tree hash H from the 8 leaf values in `leaves`.
adrs = Adrs(Adrs.TREE, layer=0x15)
hashes = leaves  # same list object as `leaves`, not a copy
root = hash_root(hashes, adrs, PKseed)
print("root={}".format(root))
print("OK  =4fdfa42840c84b1ddd0ea5ce46482020")