# @file for_authpath_01_32.py
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""Compute from scratch the FORS authentication paths for trees i=0..32
and the FORS public key (fors_pk) formed from the roots of those trees."""

from slh_adrs import Adrs
from slh_util import hash_root, authpath, fors_sk_gen
from slh_sha256 import PRF, F, T_len

DEBUG = True  # Set False to suppress the per-leaf ADRS/sk/pk trace output
DPRINT = print if DEBUG else lambda *a, **k: None

# Test-vector inputs. The sizes used here (16-byte hex seeds, t = 64 leaves
# per FORS tree, k = 33 trees) appear to match the SLH-DSA-SHA2-128f
# parameter set (n=16, a=6, k=33) -- confirm against the slh_* modules.
PKseed = 'FA495FB834DEFEA7CC96A81309479135'
SKseed = 'D5213BA4BB6470F1B9EDA88CBC94E627'
t = 64  # number of leaves in each FORS tree (= 2**6)
k = 33  # number of FORS trees

tree_address = 0x7cdcef4b8fdb03b0  # tree address set in every ADRS below
leaf_address = 0  # key-pair address set in every ADRS below
# One selected leaf index per FORS tree; each must lie in 0..t-1.
indices = [50, 47, 49, 35, 39, 21, 57, 21, 2, 0, 31, 56, 18, 58, 58, 31, 31, 43, 35, 26, 8, 7, 31, 13, 21, 57, 63, 20, 33, 5, 8, 32, 28]
# Validate the test vector up front, so a mistyped or missing index fails
# here with a clear message. Otherwise it surfaces later as an IndexError
# or as a silently wrong authpath.
if len(indices) != k:
    raise ValueError(f"expected {k} FORS indices, got {len(indices)}")
if not all(0 <= ix < t for ix in indices):
    raise ValueError(f"every FORS index must lie in 0..{t - 1}")
roots = []  # FORS tree roots, collected in tree order

# Debugging shortcut: only the first FORS tree is processed. Set ntrees = k
# to process all k trees; the fors_pk computation runs only when ntrees == k.
ntrees = 1  # should be k
for i in range(ntrees):  # range(k)
    # Fresh ADRS for tree i: type FORS_TREE at layer 0 with the fixed tree and
    # key-pair addresses. A new object is built per tree because it is passed
    # on to hash_root/authpath below.
    adrs = Adrs(Adrs.FORS_TREE, layer=0)
    adrs.setTreeAddress(tree_address)
    adrs.setKeyPairAddress(leaf_address)
    offset = i * t  # tree index of leaf 0 of tree i
    leaves = []  # the t FORS public values (leaves) of tree i
    skeys = []   # the matching t FORS secret values

    # Compute all t FORS sk and pk values for tree i: pk = F(PK.seed, ADRS, sk)
    for j in range(t):
        treeindex = offset + j
        adrs.setTreeIndex(treeindex)
        DPRINT(f"ADRS={adrs.toHex(True)}")
        sk = fors_sk_gen(SKseed, PKseed, adrs)
        DPRINT(f"fors_sk[{i}][{j}]={sk}")
        pk = F(PKseed, adrs.toHex(True), sk)
        DPRINT(f"fors_pk[{i}][{j}]={pk}")
        leaves.append(pk)
        skeys.append(sk)

    # Report the selected leaf, then compute the tree root and the
    # authentication path for indices[i]. Known-answer checks apply to tree 0.
    idx = indices[i]
    print(f"fors_sk[{i}][{idx}]={skeys[idx]}")
    if i == 0:
        assert skeys[idx] == "925bb207d49e62bcb9b1c4685154a8b3"
    print(f"fors_pk[{i}][{idx}]={leaves[idx]}")
    print(f"ADRS={adrs.toHexSP(True)}")
    root = hash_root(leaves, adrs, PKseed, offset, showdebug=True)
    print(f"root[{i}]={root}")
    if i == 0:
        assert root == "fb0f55d9717066cf3c666854d1e2f928"
    roots.append(root)

    auth = authpath(leaves, adrs, PKseed, idx, offset, showdebug=False)
    print(f"fors_auth_path[{i}]:")
    for node in auth:
        print(node)
    # Expected fors_auth_path[0]:
    #   2e58b70c7aed0e28507f31b49ec7ed6e
    #   d6dcb8db2da90fe938994d75c80e6712
    #   f2421c22def8af88906b768333e7ebf6
    #   ddf7b84dc01f06731dd640cf93f57927
    #   bb56f9da9d4b2abe60c81d863a20f8e5
    #   c5cce74326d6181d01b74e3cd7f794a9
    if i == 0:
        assert auth[5] == 'c5cce74326d6181d01b74e3cd7f794a9'

print("roots:", roots)


def _fors_roots_adrs(tree_addr, keypair_addr):
    """Build the FORS_ROOTS-type ADRS (layer 0) used to compress the roots.

    Sets the given tree address first, then the key-pair address.
    """
    roots_adrs = Adrs(Adrs.FORS_ROOTS, layer=0)
    roots_adrs.setTreeAddress(tree_addr)
    roots_adrs.setKeyPairAddress(keypair_addr)
    return roots_adrs


# The FORS public key needs the roots of all k trees, so it is only formed
# when every tree was processed (ntrees == k).
if ntrees == k:
    pkAdrs = _fors_roots_adrs(tree_address, leaf_address)
    print(f"ADRS={pkAdrs.toHexSP(True)}")
    # fors_pk = T_len(PK.seed, ADRS, root[0] || root[1] || ... || root[k-1])
    fors_pk = T_len(PKseed, pkAdrs.toHex(True), "".join(roots))
    print(f"fors_pk={fors_pk}")
    # Known-answer check for this test vector.
    assert fors_pk == "33af163817cd6c2bea881ddf7d2b89ab"