# @file spx_utils.py (2023-03-16T14:29Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""SPHINCS+ utilities for tree calculations."""

from spx_adrs import Adrs
from spx_sha256 import F, H

def chain(X, i, s, PKseed, adrs_hex, showdebug=False):
    """Iterate the hash function F over X (chain, unrolled as a loop).

    Starting from X, F is applied once for every hash address in
    range(i, s), i.e. addresses i, i+1, ..., s-1, feeding each output into
    the next call. If i >= s, X is returned unchanged.

    Args:
        X: Starting value (hex) passed to F.
        i: First hash address to use.
        s: Exclusive end hash address (not a count of steps).
        PKseed: PK.seed value in hex.
        adrs_hex: ADRS in hex; it is parsed into a fresh Adrs object, so the
            caller's string is never modified.
        showdebug: If True, print the ADRS and F output at every step.
    Returns:
        The value after the last F application.

    NOTE(review): in the SPHINCS+ specification chain(X, i, s) applies s
    steps (hash addresses i..i+s-1); here s acts as the end address, so the
    two agree only when i == 0 -- confirm callers pass s accordingly.
    """
    addr = Adrs.fromHex(adrs_hex)
    value = X
    for step in range(i, s):
        step_adrs = addr.setHashAddress(step).toHex()
        if showdebug:
            print(f"adrs={step_adrs}")
        value = F(PKseed, step_adrs, value)
        if showdebug:
            print(f"F({step})=", value)
    return value


def hash_pairwise(hashes, adrs, PKseed, startidx=0, showdebug=False):
    """Hash 2n hex values pairwise, producing the n parent nodes.

    Args:
        hashes: Array of 2*n node values to be hashed pairwise using H.
        adrs: ADRS object set for layer and treeaddress with
              either type=TREE or FORS_TREE. Its tree index is overwritten
              with startidx + k before hashing the k-th pair.
        PKseed: PK.seed value in hex.
        startidx: start index = index of leftmost first parent node.
        showdebug: If True, print the ADRS used for each pair.
    Returns:
        Array of n parent node values H(node_{2k} || node_{2k+1})
        for 0 <= k < n.
    PRE:
        len(hashes) is a power of 2 (an odd trailing element is ignored).
    """
    parents = []
    for k in range(len(hashes) // 2):
        adrs.setTreeIndex(startidx + k)
        adrs_hex = adrs.toHex()
        if showdebug:
            print(f"ADRS={adrs_hex}")
        left, right = hashes[2 * k], hashes[2 * k + 1]
        parents.append(H(PKseed, adrs_hex, left, right))
    return parents


def hash_root(hashes, adrs, PKseed, startidx=0, showdebug=False):
    """Compute the root of a Merkle tree built from 2^z leaf values.

    Args:
        hashes: Array of 2^z leaf values to be hashed pairwise using H.
        adrs: ADRS object set for layer and treeaddress with
              either type=TREE or FORS_TREE. Its tree height and tree
              index are overwritten while the tree is climbed.
        PKseed: PK.seed value in hex.
        startidx: start index = index of leftmost leaf.
        showdebug: If True, print each level of computed nodes.
    Returns:
        The root node value (the single leaf itself when z == 0).
    Raises:
        ValueError: if len(hashes) is not a positive power of 2.
    PRE:
        startidx mod 2^z = 0.
    """
    n = len(hashes)
    # A non-power-of-2 count would make hash_pairwise silently drop the odd
    # node at some level and yield a wrong root; an empty list has no root.
    if n == 0 or n & (n - 1):
        raise ValueError(f"number of leaves must be a power of 2, got {n}")
    # Leaves are at tree height 0
    treeht = 0
    while len(hashes) > 1:
        treeht += 1
        adrs.setTreeHeight(treeht)
        # Index of the leftmost node at the new (parent) height.
        startidx //= 2
        hashes = hash_pairwise(hashes, adrs, PKseed, startidx, showdebug)
        if showdebug:
            print(hashes)
    return hashes[0]


def authpath(leaves, adrs, PKseed, leaf_idx, startidx=0, showdebug=False):
    """Compute authpath array.
    Args:
        leaves: Array of 2^z leaf values at treeHeight 0.
        adrs: ADRS object set for layer and treeaddress with
              either type=TREE or FORS_TREE. Its tree height and tree
              index are overwritten while the tree is climbed.
        PKseed: PK.seed value in hex
        leaf_idx: 0-based index of leaf to be authenticated.
        startidx: start index = index of leftmost leaf
        showdebug: If True, print intermediate ADRS values.
    Returns:
        Array of z sibling node values (hex) forming the authentication
        path, ordered from the leaf level (height 0) upwards.
    Raises:
        ValueError: if len(leaves) is not a positive power of 2, or if
            leaf_idx is not in range(len(leaves)).
    PRE:
        startidx mod 2^z = 0.
    """
    n = len(leaves)
    # Without these checks a bad size or index silently yields a wrong path
    # (e.g. a negative leaf_idx makes i ^ 1 a negative Python index).
    if n == 0 or n & (n - 1):
        raise ValueError(f"number of leaves must be a power of 2, got {n}")
    if not 0 <= leaf_idx < n:
        raise ValueError(f"leaf_idx {leaf_idx} out of range for {n} leaves")
    auth = []
    treeht = 0
    i = leaf_idx
    while len(leaves) > 1:
        # The sibling of the current node is the value needed at this level
        auth.append(leaves[i ^ 1])
        treeht += 1
        i //= 2
        startidx //= 2
        adrs.setTreeHeight(treeht)
        leaves = hash_pairwise(leaves, adrs, PKseed, startidx, showdebug)
    return auth


if __name__ == '__main__':
    # Known-answer test: compute the root node of a Merkle tree using H
    PKseed = 'B505D7CFAD1B497499323C8686325E47'
    # Start with 8 leaf values in array
    leaves = ['505df0061b7e0041c8501bc5030ad439',
              '7bd5deb67217d33505043e204d88f687',
              '03b03bb327c9b48beab7722c4d5eb906',
              'fa1ef7c928518b1afdebddd1b83a3b66',
              '44b4dad150fdf64b6aa7fab1aea016e6',
              '0913211acf332a24629915d1b8226ff2',
              'a8fca106e9c1263dda280988f59f13e2',
              '84035916aba8e0b92f73364d4bb50a18']
    # Top-most tree out of 22, layer=21
    adrs = Adrs(Adrs.TREE, layer=0x15)
    expected = '4fdfa42840c84b1ddd0ea5ce46482020'
    root = hash_root(leaves, adrs, PKseed)
    print(f"root={root}")
    print(f"OK  ={expected}")
    # Check automatically rather than relying on a visual comparison.
    # Compared case-insensitively in case H emits upper-case hex.
    if root.lower() != expected:
        raise SystemExit("FAIL: computed root does not match expected value")