# @file wots_ht_0_authpath.py
# @version 1.1 (2026-01-16T14:29Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""Compute root node and authpath for bottom HT subtree at layer 0."""

from slh_adrs import Adrs
from slh_sha256 import F, PRF, T_len
from slh_util import chain, hash_root, authpath

# Global vars: the two seeds (16-byte values written as hex strings) and the
# Winternitz parameter w.
PKseed = 'FA495FB834DEFEA7CC96A81309479135'
SKseed = 'D5213BA4BB6470F1B9EDA88CBC94E627'
w = 16

# The subtree being processed: tree `tree_addr` at HT layer 0.
layer, tree_addr = 0, 0x7CDCEF4B8FDB03B0
wots_len = 35      # WOTS+ chains per key pair (one chain per inner-loop pass)
nleaves = 2 ** 3   # = 8 leaves in this bottom subtree
leaf_idx = 0       # only needed for the authpath computation

# Leaf values for the subtree, filled in one per key pair by the loop below.
leaves = []

# Compute all leaves of the bottom subtree at layer 0. Each leaf is the
# WOTS+ public key of key pair `this_leaf`, compressed with T_len.
for this_leaf in range(nleaves):
    print(f"this_leaf={this_leaf}")
    adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(this_leaf)
    print(adrs.toHexSP(True))

    # [v3.1] Use WOTS_PRF to create sk, but WOTS_HASH for pk.
    skAdrs = adrs.copy()
    skAdrs.setType(Adrs.WOTS_PRF)
    # Re-apply the key pair address after the type change. In FIPS 205,
    # setTypeAndClear clears it; presumably setType does the same here
    # (TODO confirm in slh_adrs).
    skAdrs.setKeyPairAddress(adrs.getKeyPairAddress())

    chain_heads = []  # end value ("head") of each WOTS+ chain, in chain order
    for chainaddr in range(wots_len):
        # Compute secret value for chain `chainaddr`
        skAdrs.setChainAddress(chainaddr)
        sk = PRF(PKseed, SKseed, skAdrs.toHex(True))
        print(f"sk[{chainaddr}]={sk}")
        # Compute public value for chain `chainaddr` by chaining from index 0
        # for w - 1 steps. Detailed debug output is shown only for the first
        # two chains and the last one, to keep the log readable.
        adrs.setChainAddress(chainaddr)
        pk = chain(sk, 0, w - 1, PKseed, adrs.toHex(True),
                   showdebug=(chainaddr < 2 or chainaddr == wots_len - 1))
        print(f"pk={pk}")
        chain_heads.append(pk)

    # Concatenation of all chain heads = input to the thash compression.
    # Use join rather than repeated string `+=` in the loop.
    heads = "".join(chain_heads)
    print(f"Input to thash:\n{heads}")

    # The compression uses a separate WOTS_PK address for the same key pair.
    wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=layer)
    wots_pk_adrs.setTreeAddress(tree_addr)
    wots_pk_adrs.setKeyPairAddress(this_leaf)
    print(f"wots_pk_addr={wots_pk_adrs.toHexSP(True)}")
    # Compress public key
    leaf = T_len(PKseed, wots_pk_adrs.toHex(True), heads)
    print(f"leaf[{this_leaf}]={leaf}")
    leaves.append(leaf)

print(leaves)

# Compute the root node of the Merkle tree using H, starting from the
# `nleaves` leaf values collected above.
adrs = Adrs(Adrs.TREE, layer=layer)
adrs.setTreeAddress(tree_addr)
print(f"ADRS={adrs.toHex(True)}")
root = hash_root(leaves, adrs, PKseed)

# Known-answer value for this subtree root.
expected_root = 'af0daeccc5501e78851bf9a7896945b5'
print(f"root={root}")
print(f"OK  ={expected_root}")
# Explicit check instead of `assert`, so it is not stripped under `python -O`.
if root != expected_root:
    raise AssertionError(f"root mismatch: got {root}, expected {expected_root}")

# Compute the authentication path for leaf `leaf_idx` of the same subtree.
print(f"Computing authpath for leaf index {leaf_idx}...")
# NOTE(review): this reuses the TREE `adrs` that was just passed to hash_root().
# It assumes authpath() sets any height/index fields it needs itself
# (TODO confirm in slh_util).
auth = authpath(leaves, adrs, PKseed, leaf_idx)
print("authpath:")
for node in auth:
    print(node)

# Known-answer value for the third authpath node.
expected_auth2 = '88378feae0988ef4202eea1238ada111'
# Explicit check instead of `assert`, so it is not stripped under `python -O`.
if auth[2] != expected_auth2:
    raise AssertionError(
        f"authpath[2] mismatch: got {auth[2]}, expected {expected_auth2}")