# @file slh_tree.py 
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""SLH-DSA utilities for tree calculations."""

import collections
from slh_params import params
from slh_adrs import Adrs
import slh_hashfuncs as hashfuncs


def chain(X, i, s, PKseed, adrs_hex, params, showdebug=False):
    """Apply F repeatedly to X, once for each hash address in [i, s).

    Unrolled chaining function.

    Args:
        X: Starting chain value, passed straight through to params.F.
        i: First hash address to use (inclusive).
        s: Hash address at which to stop (exclusive), so F is applied
           max(0, s - i) times.
           NOTE(review): FIPS 205 Algorithm 5 treats its third argument
           as a step count (j runs from i to i+s-1). Here it is an end
           index, so callers must pass i + steps.
        PKseed: PK.seed value in hex.
        adrs_hex: ADRS value in hex; parsed with Adrs.fromHex.
        params: Parameter set providing F and compr.
        showdebug: Print each ADRS and each intermediate F output.
    Returns:
        The value after the last F application (X itself if i >= s).
    """
    address = Adrs.fromHex(adrs_hex)
    value = X
    for step in range(i, s):
        # Each F call in the chain uses its own hash address.
        step_adrs = address.setHashAddress(step).toHex(params.compr)
        if showdebug:
            print(f"adrs={step_adrs}")
        value = params.F(PKseed, step_adrs, value, params)
        if showdebug:
            print(f"F({step})=", value)
    return value


def hash_pairwise(hashes, adrs, PKseed, params, startidx=0, showdebug=False):
    """Hash a level of node values pairwise to get the level above.

    Args:
        hashes: Even-length list of node values to be hashed pairwise using H.
        adrs: ADRS object set for layer and tree address, with either
              type=TREE or FORS_TREE. The caller sets the tree height.
              This function overwrites the tree index for each parent.
        PKseed: PK.seed value in hex.
        params: Parameter set providing H and compr.
        startidx: Tree index of the leftmost parent node.
        showdebug: Print the ADRS used for each H call.
    Returns:
        List of len(hashes) // 2 parent values. Element i is
        H(PKseed, ADRS(treeIndex=startidx + i), hashes[2i], hashes[2i+1]).
    Raises:
        ValueError: If len(hashes) is odd. The last node would have no
            sibling, and silently dropping it would yield a wrong tree.
    PRE:
        When building a full tree, len(hashes) is a power of 2.
    """
    if len(hashes) % 2:
        raise ValueError(
            f"hash_pairwise needs an even number of nodes, got {len(hashes)}")
    out = []
    # Pair even-indexed (left) nodes with odd-indexed (right) siblings.
    for i, (left, right) in enumerate(zip(hashes[0::2], hashes[1::2])):
        adrs.setTreeIndex(startidx + i)
        adrs_hex = adrs.toHex(params.compr)
        if showdebug:
            print(f"ADRS={adrs_hex}")
        out.append(params.H(PKseed, adrs_hex, left, right, params))
    return out


def hash_root(hashes, adrs, PKseed, params, startidx=0, showdebug=False):
    """Compute the root of a Merkle tree with 2^z leaves.

    Args:
        hashes: List of 2^z leaf values (tree height 0), hashed pairwise
                using H, level by level, up to the root.
        adrs: ADRS object set for layer and tree address, with either
              type=TREE or FORS_TREE. Its tree height is overwritten at
              each level, and its tree index is overwritten via
              hash_pairwise.
        PKseed: PK.seed value in hex.
        params: Parameter set providing H and compr.
        startidx: Tree index of the leftmost leaf.
        showdebug: Print each ADRS and each computed level.
    Returns:
        The root node value. A single leaf is returned unchanged.
    Raises:
        ValueError: If hashes is empty, if len(hashes) is not a power of 2,
            or if startidx is not a multiple of len(hashes). Any of these
            would otherwise pair non-sibling nodes or drop nodes silently,
            producing a wrong root.
    """
    count = len(hashes)
    if count == 0 or count & (count - 1):
        raise ValueError(f"number of leaves must be a power of 2, got {count}")
    if startidx % count:
        raise ValueError(
            f"startidx {startidx} is not aligned to a {count}-leaf subtree")
    # Leaves are at tree height 0
    treeht = 0
    while len(hashes) > 1:
        treeht += 1
        adrs.setTreeHeight(treeht)
        # The leftmost parent on the next level has half the child's index.
        startidx //= 2
        hashes = hash_pairwise(hashes, adrs, PKseed, params, startidx, showdebug)
        if showdebug:
            print(hashes)
    return hashes[0]


def authpath(leaves, adrs, PKseed, leaf_idx, params, startidx=0, showdebug=False):
    """Compute authpath array.

    Args:
        leaves: List of 2^z leaf values at treeHeight 0.
        adrs: ADRS object set for layer and treeaddress with
              either type=TREE or FORS_TREE. Its tree height is overwritten
              at each level, and its tree index via hash_pairwise.
        PKseed: PK.seed value in hex.
        leaf_idx: 0-based index, within `leaves`, of the leaf to be
                  authenticated.
        params: Parameter set providing H and compr, e.g. as built by
                slh_params.params.
        startidx: Tree index of the leftmost leaf.
        showdebug: Display debugging output.
    Returns:
        List of z sibling node values, from the leaf level upwards, as
        hex values. An empty list for a single leaf.
    Raises:
        ValueError: If len(leaves) is not a power of 2, if startidx is not
            a multiple of len(leaves), or if leaf_idx is outside
            [0, len(leaves)). A negative leaf_idx would otherwise index
            from the end of the list and silently return a wrong path.
    """
    count = len(leaves)
    if count == 0 or count & (count - 1):
        raise ValueError(f"number of leaves must be a power of 2, got {count}")
    if startidx % count:
        raise ValueError(
            f"startidx {startidx} is not aligned to a {count}-leaf subtree")
    if not 0 <= leaf_idx < count:
        raise ValueError(f"leaf_idx {leaf_idx} out of range for {count} leaves")
    auth = []
    treeht = 0
    i = leaf_idx
    while len(leaves) > 1:
        # The sibling of node i (flip the low bit) is the auth node at this level.
        auth.append(leaves[i ^ 1])
        treeht += 1
        i //= 2
        startidx //= 2
        adrs.setTreeHeight(treeht)
        # Build the next level up. The last pass computes the root, which is
        # not returned, but it leaves adrs set for the root node.
        leaves = hash_pairwise(leaves, adrs, PKseed, params, startidx, showdebug)
    return auth


if __name__ == '__main__':

    # Self-test: compute the root node of an 8-leaf Merkle tree using H,
    # for two parameter sets, and print it above the known-good value.
    # NOTE(review): these are 16-byte values even for SHAKE-256s, where n=32.
    # The expected root for that set presumably reflects these exact inputs.
    leaves = [
        '505df0061b7e0041c8501bc5030ad439',
        '7bd5deb67217d33505043e204d88f687',
        '03b03bb327c9b48beab7722c4d5eb906',
        'fa1ef7c928518b1afdebddd1b83a3b66',
        '44b4dad150fdf64b6aa7fab1aea016e6',
        '0913211acf332a24629915d1b8226ff2',
        'a8fca106e9c1263dda280988f59f13e2',
        '84035916aba8e0b92f73364d4bb50a18',
    ]
    # Just for testing - we only need H, so every other function slot is empty.
    unused_slots = dict(H_msg=None, PRF=None, PRF_msg=None, F=None, T_len=None)
    test_vectors = [
        # (parameter fields, PK.seed, layer of top-most tree, expected root)
        (dict(name='SHA2-128f', n=16, h=66, d=22, a=6, k=33, lgw=4, m=34,
              compr=True, H=hashfuncs.H_sha256),
         'B505D7CFAD1B497499323C8686325E47',
         0x15,  # top-most tree out of 22, layer=21
         '4fdfa42840c84b1ddd0ea5ce46482020'),
        (dict(name='SHAKE-256s', n=32, h=64, d=8, a=14, k=22, lgw=4, m=34,
              compr=False, H=hashfuncs.H_shake),
         '3E784CCB7EBCDCFD45542B7F6AF778742E0F4479175084AA488B3B74340678AA',
         7,  # top-most tree out of 8, layer=7
         'bd1970d9000632287b4b3988c66238328e307d023b9971685c369d53af4b51ad'),
    ]
    for fields, PKseed, layer, expected in test_vectors:
        params_test = params(**fields, **unused_slots)
        print(params_test.name)
        adrs = Adrs(Adrs.TREE, layer=layer)
        root = hash_root(leaves, adrs, PKseed, params_test)
        print(f"root={root}")
        print(f"OK  ={expected}")