# @file slh_tree.py
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0
"""SLH-DSA utilities for tree calculations."""
import collections
from slh_params import params
from slh_adrs import Adrs
import slh_hashfuncs as hashfuncs
def chain(X, i, s, PKseed, adrs_hex, params, showdebug=False):
    """Apply F to X once for each hash address i, i+1, ..., s-1 (unrolled chain).

    Args:
        X: starting chain value (hex) fed to params.F.
        i: first hash address to use.
        s: hash address at which to stop (exclusive end of the range).
        PKseed: PK.seed value in hex.
        adrs_hex: ADRS as a hex string; parsed once with Adrs.fromHex.
        params: parameter set providing F and compr.
        showdebug: print the ADRS and F output at every step.
    Returns:
        The value after s - i applications of F; X unchanged when s <= i.

    NOTE(review): FIPS 205 chain(X, i, s, ...) treats s as a step count
    (addresses i .. i+s-1), whereas here s is the exclusive end address.
    These agree only when i == 0 -- confirm callers pass an end address.
    """
    address = Adrs.fromHex(adrs_hex)
    value = X
    for hashaddr in range(i, s):
        # setHashAddress returns the ADRS object, so it can be chained.
        step_hex = address.setHashAddress(hashaddr).toHex(params.compr)
        if showdebug:
            print(f"adrs={step_hex}")
        value = params.F(PKseed, step_hex, value, params)
        if showdebug:
            print(f"F({hashaddr})=", value)
    return value
def hash_pairwise(hashes, adrs, PKseed, params, startidx=0, showdebug=False):
    """Combine 2n node values pairwise with H, giving the n parent nodes.

    Args:
        hashes: list of 2n child values; hashes[2j] and hashes[2j+1] are the
            left and right children of parent j.
        adrs: ADRS object set for layer and tree address, with type TREE or
            FORS_TREE. Its tree index is overwritten for every parent, so on
            return it holds the index of the last parent computed.
        PKseed: PK.seed value in hex.
        params: parameter set providing H and compr.
        startidx: tree index of the leftmost parent node.
        showdebug: print the ADRS used for each H call.
    Returns:
        List of n parents, parent j = H(hashes[2j] || hashes[2j+1]) computed
        with tree index startidx + j.
    PRE: len(hashes) is a power of 2 (an odd trailing element is ignored).
    """
    # zip() stops at the shorter slice, so an unpaired last element is dropped.
    pairs = zip(hashes[0::2], hashes[1::2])
    parents = []
    for offset, (left, right) in enumerate(pairs):
        adrs.setTreeIndex(startidx + offset)
        node_adrs = adrs.toHex(params.compr)
        if showdebug:
            print(f"ADRS={node_adrs}")
        parents.append(params.H(PKseed, node_adrs, left, right, params))
    return parents
def hash_root(hashes, adrs, PKseed, params, startidx=0, showdebug=False):
    """Compute the root of a Merkle tree over 2^z leaf values.

    Levels are built bottom-up with hash_pairwise until one node remains.

    Args:
        hashes: list of leaf values (tree height 0) to be combined with H.
        adrs: ADRS object set for layer and tree address, with type TREE or
            FORS_TREE. Its tree height and tree index are overwritten.
        PKseed: PK.seed value in hex.
        params: parameter set passed through to hash_pairwise.
        startidx: tree index of the leftmost leaf.
        showdebug: print every computed level.
    Returns:
        The single remaining node value (the root). An empty `hashes` raises
        IndexError.
    PRE: len(hashes) is a power of 2; startidx mod 2^z = 0.
    """
    level = hashes
    height = 0
    while len(level) > 1:
        height += 1
        # A parent's tree index is its left child's index halved.
        startidx //= 2
        adrs.setTreeHeight(height)
        level = hash_pairwise(level, adrs, PKseed, params, startidx, showdebug)
        if showdebug:
            print(level)
    return level[0]
def authpath(leaves, adrs, PKseed, leaf_idx, params, startidx=0, showdebug=False):
    """Compute the authentication path for one leaf of a Merkle tree.

    Args:
        leaves: Array of 2^z leaf values at treeHeight 0.
        adrs: ADRS object set for layer and treeaddress with
            either type=TREE or FORS_TREE. Its tree height and tree index
            are overwritten as each level is computed.
        PKseed: PK.seed value in hex.
        leaf_idx: 0-based index of the leaf to be authenticated;
            must satisfy 0 <= leaf_idx < len(leaves).
        params: parameter set from slh_params (used by hash_pairwise).
        startidx: start index = index of leftmost leaf.
        showdebug: display debugging output.
    Returns:
        List of z sibling node values (hex), ordered from the leaf level
        (height 0) upwards; empty for a single-leaf tree.
    Raises:
        IndexError: if leaf_idx is outside range(len(leaves)), including
            any negative index and an empty `leaves`.
    PRE:
        len(leaves) is a power of 2; startidx mod 2^z = 0.
    """
    if not 0 <= leaf_idx < len(leaves):
        # Without this check a negative index would be accepted by Python's
        # negative indexing and silently yield siblings counted from the end.
        raise IndexError(
            f"leaf_idx {leaf_idx} out of range for {len(leaves)} leaves")
    auth = []
    treeht = 0
    i = leaf_idx
    while len(leaves) > 1:
        # Flipping the low bit gives the sibling of node i at this height.
        auth.append(leaves[i ^ 1])
        treeht += 1
        i //= 2
        startidx //= 2
        adrs.setTreeHeight(treeht)
        # Replace this level with its parents for the next iteration.
        leaves = hash_pairwise(leaves, adrs, PKseed, params, startidx, showdebug)
    return auth
if __name__ == '__main__':
    # Self-test: compute the root node of a Merkle tree over 8 fixed leaves
    # using only H, and print it next to the known-good value.
    leaves = ['505df0061b7e0041c8501bc5030ad439',
              '7bd5deb67217d33505043e204d88f687',
              '03b03bb327c9b48beab7722c4d5eb906',
              'fa1ef7c928518b1afdebddd1b83a3b66',
              '44b4dad150fdf64b6aa7fab1aea016e6',
              '0913211acf332a24629915d1b8226ff2',
              'a8fca106e9c1263dda280988f59f13e2',
              '84035916aba8e0b92f73364d4bb50a18']

    def check_root(params_test, PKseed, layer, expected):
        """Print the parameter-set name, the computed root and the expected root."""
        print(params_test.name)
        adrs = Adrs(Adrs.TREE, layer=layer)
        root = hash_root(leaves, adrs, PKseed, params_test)
        print(f"root={root}")
        print(f"OK ={expected}")

    # Top-most tree out of 22, layer=21
    check_root(
        params(name='SHA2-128f', n=16, h=66, d=22, a=6, k=33, lgw=4, m=34, compr=True,
               H_msg=None, PRF=None, PRF_msg=None, F=None, H=hashfuncs.H_sha256, T_len=None),
        'B505D7CFAD1B497499323C8686325E47',
        0x15,
        '4fdfa42840c84b1ddd0ea5ce46482020')

    # Top-most tree out of 8, layer=7
    # NOTE(review): the same 16-byte leaves are reused although SHAKE-256s
    # has n=32; the expected root was presumably computed from these exact
    # inputs -- confirm before replacing them with 32-byte leaves.
    check_root(
        params(name='SHAKE-256s', n=32, h=64, d=8, a=14, k=22, lgw=4, m=34, compr=False,
               H_msg=None, PRF=None, PRF_msg=None, F=None, H=hashfuncs.H_shake, T_len=None),
        '3E784CCB7EBCDCFD45542B7F6AF778742E0F4479175084AA488B3B74340678AA',
        7,
        'bd1970d9000632287b4b3988c66238328e307d023b9971685c369d53af4b51ad')