# @file merkle_tree_4byte_treehash.py
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""Generate a Merkle tree using a weak 4-byte hash function
and compute the authentication path (AUTH) using the TreeHash algorithm."""

import hashlib

# DEBUGGING
# Set DEBUG = True to route diagnostic DPRINT(...) calls to print().
DEBUG = False
if DEBUG:
    DPRINT = print
else:
    def DPRINT(*args, **kwargs):
        """Discard all arguments (debug output disabled).

        Note: the caller still evaluates every argument before this no-op
        runs, so expressions passed to DPRINT are not skipped.
        """
        return None

def H(val: str) -> str:
    """Hash hex-encoded data with truncated SHA-256 (a deliberately weak hash).

    :param str val: Hex-encoded input.
    :return: First 4 bytes of SHA-256 of the decoded input, as 8 lowercase
             hex digits.
    :rtype: str
    """
    data = bytes.fromhex(val)
    # Keep only the first four bytes of the digest to make the hash "weak".
    return hashlib.sha256(data).digest()[:4].hex()

def is_power_of_two(n: int) -> bool:
    """Return True if n is a positive integral power of two (1, 2, 4, ...).

    A power of two has exactly one bit set, so clearing its lowest set bit
    with ``n & (n - 1)`` leaves zero.
    """
    if n <= 0:
        return False
    return not n & (n - 1)

def hash_pairwise(hashes):
    """Combine adjacent pairs of hex-encoded nodes into their parent hashes.

    :param list hashes: Array of 2*n hex strings to be hashed pairwise.
    :return: Array of n hex-encoded hash strings H(node_{2i} || node_{2i+1})
             for 0 <= i < n. If the input has odd length, the trailing
             unpaired element is ignored.
    :rtype: list
    """
    lefts = hashes[0::2]
    rights = hashes[1::2]
    return [H(left + right) for left, right in zip(lefts, rights)]

def make_sk(seed, height, index):
    """Generate a private key value given a secret key seed and address.

    The address is ``height`` followed by ``index``, each encoded as exactly
    two bytes (four hex digits, most significant first).

    :param str seed: Secret key seed (hex string).
    :param int height: Height value to encode into address (0..0xFFFF).
    :param int index: Index value to encode into address (0..0xFFFF).
    :return: Derived private key value (4-byte hash as hex string).
    :rtype: str
    :raises ValueError: If ``height`` or ``index`` does not fit in 16 bits.
    """
    # "{:04x}" only sets a *minimum* width. A value above 0xFFFF would widen
    # its field and make the address ambiguous: (0x1234, 0x56789a) and
    # (0x123456, 0x789a) both give "123456789a". A negative value would
    # inject a '-' sign. Reject both so every (height, index) pair maps to a
    # unique 4-byte address.
    for name, value in (("height", height), ("index", index)):
        if not 0 <= value <= 0xFFFF:
            raise ValueError(f"{name} must be in range 0..0xFFFF, got {value!r}")
    adrs = "{:04x}".format(height) + "{:04x}".format(index)
    DPRINT("adrs=", adrs, bytes.fromhex(adrs))
    return H(seed + adrs)
  

# Generate all keys at leaf level 0
sk_seed = '900150983cd24fb0'  # The first 8 bytes of MD5('abc')
keys = [make_sk(sk_seed, 0, leaf) for leaf in range(8)]
print("keys =", keys)
print(f"size={len(keys)}")
# Pairwise hashing below needs a complete binary tree.
assert is_power_of_two(len(keys))

# Hash each key to obtain the leaf nodes (level 0 of the tree)
hashes = list(map(H, keys))
print("leaf nodes=", hashes)

# Fold the tree upward one level at a time until only the root remains,
# printing every intermediate level along the way.
while len(hashes) > 1:
    hashes = hash_pairwise(hashes)
    print(hashes)

print(f"Root node is {hashes[0]}")
print('-' * 16)

# Compute tree nodes and the authentication path (AUTH)
def compute_node(sk_seed, i, z):
    """Compute a node in the Merkle tree by recomputing its whole subtree.

    A leaf (z == 0) is H(sk) where sk = make_sk(sk_seed, 0, i). An internal
    node at height z and index i is H(left || right), where the children are
    the nodes at height z-1 with indices 2i and 2i+1. This derives all
    2**z leaves underneath the node.

    :param str sk_seed: Secret key seed (hex string).
    :param int i: Index of the node within its level.
    :param int z: Height of the node (0 = leaf).
    :return: Hash value of the node (hex string).
    :rtype: str
    """
    if z == 0:
        leaf_sk = make_sk(sk_seed, 0, i)
        leaf = H(leaf_sk)
        DPRINT(f"node: i = {i} sk={leaf_sk} node={leaf}")
        return leaf

    left = compute_node(sk_seed, 2 * i, z - 1)
    right = compute_node(sk_seed, 2 * i + 1, z - 1)
    parent = H(left + right)
    DPRINT(f"node: ({z},{i}) {left}||{right} = {parent}")
    return parent

def gen_auth_path(sk_seed, a, idx):
    """Generate authentication path (AUTH) for a leaf.

    AUTH[j], for 0 <= j < a, is the sibling of the height-j node on the path
    from leaf ``idx`` up to the root.

    :param str sk_seed: Secret key seed (hex string).
    :param int a: Height of the tree, i.e. the length of the AUTH path.
    :param int idx: Leaf index for which to generate the AUTH; 0 <= idx < 2**a.
    :return: List of node hashes (hex strings) forming the AUTH path.
    :rtype: list
    :raises ValueError: If ``a`` is negative or ``idx`` is not a leaf index
        of a tree of height ``a``.
    """
    if a < 0:
        raise ValueError(f"tree height a must be non-negative, got {a!r}")
    # Without this check, an out-of-range idx silently yields a "path" built
    # from nodes that are not part of this tree.
    if not 0 <= idx < (1 << a):
        raise ValueError(
            f"leaf index must be in range 0..{(1 << a) - 1}, got {idx!r}")
    auth = []
    for j in range(a):
        # The height-j ancestor of leaf idx has index idx >> j; flipping its
        # low bit gives the index of that ancestor's sibling.
        s = (idx >> j) ^ 1
        DPRINT(f"j={j} s={s}")
        auth.append(compute_node(sk_seed, s, j))
        DPRINT(f"auth[{j}]={auth[j]}")
    return auth
        
# Tree height: len(keys) is a power of two, so 2**a == len(keys) (8 -> 3).
a = len(keys).bit_length() - 1
idx = 6
print(f"leaf_index={idx}")
print(f"sk[{idx}]={keys[idx]}")
auth = gen_auth_path(sk_seed, a, idx)
print("auth =", auth)
# auth = ['8fc53ad8', '7890c8b6', '076f83f6']

# Compute root node
DPRINT("Computing root...")
root = compute_node(sk_seed, 0, a)
print(f"root={root}")
assert root == '03583268'
# The recursive computation must agree with the level-by-level tree above.
assert root == hashes[0]


def root_from_auth(leaf, leaf_idx, auth_path):
    """Rebuild the Merkle root from a leaf node and its AUTH path.

    :param str leaf: Leaf node value (hex string), i.e. H(sk[leaf_idx]).
    :param int leaf_idx: Index of the leaf.
    :param list auth_path: Sibling nodes from height 0 upwards.
    :return: The reconstructed root node (hex string).
    :rtype: str
    """
    node = leaf
    for j, sibling in enumerate(auth_path):
        # At height j the current node has index leaf_idx >> j: an even index
        # is a left child (sibling on the right), an odd one is a right child.
        if (leaf_idx >> j) & 1:
            node = H(sibling + node)
        else:
            node = H(node + sibling)
    return node


# Verify AUTH: the leaf plus its authentication path must yield the root.
auth_root = root_from_auth(H(keys[idx]), idx, auth)
DPRINT(f"root from auth path={auth_root}")
assert auth_root == root, "AUTH path does not reconstruct the root"