# @file merkle_tree_4byte_pairwise.py
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""Generate a Merkle tree using a weak 4-byte hash function
and compute the authentication path (AUTH) using the hash_pairwise algorithm."""

import hashlib

# DEBUGGING
# Set DEBUG = True to make DPRINT forward its arguments to print();
# otherwise DPRINT accepts any arguments and does nothing.
DEBUG = False
if DEBUG:
    DPRINT = print
else:
    def DPRINT(*args, **kwargs):
        """Silent stand-in for print() used when debugging is disabled."""
        return None

def H(val: str) -> str:
    """Return a deliberately weak 4-byte hash of hex-encoded input.

    The hash is SHA-256 of the decoded bytes, truncated to its first
    four bytes.

    :param str val: Hex-encoded input.
    :return: First four bytes of SHA-256(val) as lowercase hex (8 characters).
    :rtype: str
    """
    data = bytes.fromhex(val)
    full_digest = hashlib.sha256(data).digest()
    return full_digest[:4].hex()

def is_power_of_two(n: int) -> bool:
    """Return True if ``n`` is a positive integer power of two (1, 2, 4, ...).

    :param int n: Value to test.
    :return: True when exactly one bit of a positive ``n`` is set.
    :rtype: bool
    """
    if n <= 0:
        return False
    # Clearing the lowest set bit leaves zero only when a single bit was set.
    return (n & (n - 1)) == 0

def hash_pairwise(hashes):
    """Hash 2n hex values pairwise to produce the next tree level up.

    :param list hashes: Array of 2*n hex strings to be hashed pairwise.
    :return: Array of n hex-encoded hash strings H(node_{2i} || node_{2i+1})
             for 0 <= i < n.
    :rtype: list
    :raises ValueError: If ``hashes`` has an odd number of elements.
    """
    # An unpaired trailing node would otherwise be silently dropped,
    # yielding an incorrect parent level for the Merkle tree.
    if len(hashes) % 2 != 0:
        raise ValueError(f"expected an even number of hashes, got {len(hashes)}")
    # Pair element 2i with 2i+1; concatenating hex strings concatenates bytes.
    return [H(left + right) for left, right in zip(hashes[0::2], hashes[1::2])]

def make_sk(seed, height, index):
    """Generate a private key value given a secret key and address.

    The address is ``height`` followed by ``index``, each encoded as a
    fixed-width 2-byte (4 hex digit) big-endian field.

    :param str seed: Secret key seed (hex string).
    :param int height: Height value to encode into address (0..0xFFFF).
    :param int index: Index value to encode into address (0..0xFFFF).
    :return: Derived private key value (4-byte hash as hex string).
    :rtype: str
    :raises ValueError: If ``height`` or ``index`` does not fit in 2 bytes.
    """
    # The "04x" format pads but never truncates: values above 0xFFFF would
    # produce a wider field (making distinct addresses collide), and negative
    # values would produce a '-' sign (invalid hex). Reject both up front.
    for name, value in (("height", height), ("index", index)):
        if not 0 <= value <= 0xFFFF:
            raise ValueError(f"{name} must be in range 0..0xFFFF, got {value}")
    adrs = f"{height:04x}{index:04x}"
    DPRINT("adrs=", adrs, bytes.fromhex(adrs))
    return H(seed + adrs)
  

# Generate all keys at leaf level 0
sk_seed = '900150983cd24fb0'  # The first 8 bytes of MD5('abc')
keys = [make_sk(sk_seed, 0, i) for i in range(8)]
print("secret keys =", keys)
print(f"size={len(keys)}")
# A complete binary Merkle tree needs a power-of-two number of leaves.
# Raise explicitly: an `assert` would be stripped under `python -O`.
if not is_power_of_two(len(keys)):
    raise ValueError(f"number of leaves must be a power of two, got {len(keys)}")

# Hash the keys to leaf nodes
hashes = [H(k) for k in keys]
print("public keys =", hashes)

# Compute AUTHPATH given public keys in `hashes`.
# The path holds one sibling per level below the root, i.e. log2(#leaves)
# entries; derive that from the leaf count rather than hard-coding it.
a = len(hashes).bit_length() - 1
idx = 6
if not 0 <= idx < len(hashes):
    raise IndexError(f"leaf index {idx} out of range for {len(hashes)} leaves")
print(f"leaf_index={idx}")
print(f"sk[{idx}]={keys[idx]}")
authpath = []
i = idx  # index of the node on the path from leaf `idx` at the current level
for j in range(a):
    if j > 0:
        hashes = hash_pairwise(hashes)  # move up one level of the tree
        i //= 2  # the path node's parent is at half the index
    authpath.append(hashes[i ^ 1])  # XOR 1 selects the sibling node
print("authpath =", authpath)
# authpath = ['8fc53ad8', '7890c8b6', '076f83f6']