# @file wots_ht_0.py (2023-03-16T14:29Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""Compute root node and authpath for bottom HT subtree at layer 0."""

from spx_adrs import Adrs
from spx_sha256 import PRF, T_len
from spx_util import chain, hash_root, authpath

# Global parameters shared by every computation in this script.
# Both seeds are 16-byte values written as 32 hex characters.
PKseed = "B505D7CFAD1B497499323C8686325E47"  # public seed
SKseed = "7C9935A0B07694AA0C6D10E4DB6B1ADD"  # secret seed
w = 16  # Winternitz parameter

# Compute all leaves of the bottom subtree at layer 0.
# Each leaf is one key pair's WOTS+ public key: the heads (ends) of all its
# chains are concatenated and compressed with T_len.
leaves = []
layer = 0
tree_addr = 0x28daecdc86eb8761
leaf_idx = 6  # Only needed for authpath
num_leaves = 8  # number of leaves in this subtree (2**3)
# Number of WOTS+ chains per key pair: len = len1 + len2 = 32 + 3 = 35
# for n = 16-byte hashes and w = 16.
wots_len = 35

for this_leaf in range(num_leaves):
    print(f"this_leaf={this_leaf}")
    # WOTS_HASH address for this key pair; only the chain address changes
    # inside the inner loop.
    adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(this_leaf)
    print(adrs.toHex())
    chain_heads = []  # heads of the WOTS+ chains, in chain-address order
    for chainaddr in range(wots_len):
        adrs.setChainAddress(chainaddr)
        adrs_hex = adrs.toHex()
        # Secret value at the start of this chain, derived from SKseed.
        sk = PRF(SKseed, adrs_hex)
        print(f"sk[{chainaddr}]={sk}")
        # Walk the chain to its head (presumably start index 0, w - 1 steps,
        # per the SPHINCS+ chain(X, i, s) convention). Debug output is shown
        # only for the first two chains and the last one to keep logs short.
        pk = chain(sk, 0, w - 1, PKseed, adrs_hex,
                   showdebug=(chainaddr < 2 or chainaddr == wots_len - 1))
        print(f"pk={pk}")
        chain_heads.append(pk)
    heads = "".join(chain_heads)  # concatenation of heads of WOTS+ chains

    print(f"Input to thash:\n{heads}")
    # Compress the chain heads with T_len under a WOTS_PK address for this
    # key pair; the result is this subtree leaf.
    wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=layer)
    wots_pk_adrs.setTreeAddress(tree_addr)
    wots_pk_adrs.setKeyPairAddress(this_leaf)
    wots_pk_addr_hex = wots_pk_adrs.toHex()
    print(f"wots_pk_addr={wots_pk_addr_hex}")
    leaf = T_len(PKseed, wots_pk_addr_hex, heads)
    print(f"leaf[{this_leaf}]={leaf}")
    leaves.append(leaf)

print(leaves)


# Compute the root node of the Merkle subtree using H.
# hash_root receives the leaf values computed above plus a TREE-type address
# for this layer and tree, and the result is checked against a known value.
expected_root = 'f2ec3b2ae23a50355d057b97df65c8bc'
adrs = Adrs(Adrs.TREE, layer=layer)
adrs.setTreeAddress(tree_addr)
# f-string so the address is actually interpolated into the output.
print(f"ADRS={adrs.toHex()}")
root = hash_root(leaves, adrs, PKseed)
print(f"root={root}")
print(f"OK  ={expected_root}")
assert root == expected_root, f"root mismatch: got {root}, expected {expected_root}"

# Compute the authentication path for leaf_idx within this subtree and
# spot-check one node against a known value.
# NOTE(review): this reuses the TREE address `adrs` that was already passed to
# hash_root; if hash_root mutates it, this relies on authpath resetting the
# fields it uses — confirm.
print(f"Computing authpath for leaf index {leaf_idx}...")
auth = authpath(leaves, adrs, PKseed, leaf_idx)
print("authpath:")
for node in auth:
    print(node)
expected_auth2 = '77a2617d410d8f1acd1fbc29830e1a51'
assert auth[2] == expected_auth2, f"authpath[2] mismatch: got {auth[2]}"