# @file wots_authpath_verify.py (2023-03-16T14:29Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""Verify the authpath in the first WOTS subtree."""

from spx_adrs import Adrs
from spx_sha256 import H

# sk = PRF(SK.seed, ADRS)
# node = F(PK.seed, ADRS, sk)


# PK.seed value (hex string) passed as the first argument to H below.
PKseed = 'B505D7CFAD1B497499323C8686325E47'
# Required root value in WOTS subtree
root = "f2ec3b2ae23a50355d057b97df65c8bc"

# Authentication path: one sibling node value per level, ordered from the
# leaf level upward (entry i is combined with the node at height i).
authpath = [
    "b1f67e538bb9d4c2ef860f50085bcb72",
    "c10fb38ab696949b9417ddbefe8e4cad",
    "77a2617d410d8f1acd1fbc29830e1a51",
]

print("authpath=", authpath)

# Leaf node value at the bottom of the path being verified.
leaf = "56c1cd468c05d6b5a9ad57e87c4edf12"
print(f"leaf={leaf}")

# Compute node values above the leaf, one level at a time, using the
# authentication path, and check that the final node equals the required root.
# The number of levels is fixed by the path itself: one sibling per level.
nlevels = len(authpath)
level = nlevels
idx = 6  # index of the leaf within the bottom level of the subtree
height = 0
# A leaf index must fit in nlevels bits, otherwise the walk below would
# silently compute the wrong root.
if not 0 <= idx < (1 << nlevels):
    raise ValueError(
        f"leaf index {idx} out of range for a tree with {nlevels} levels")
adrs = Adrs(Adrs.TREE, layer=0)
adrs.setTreeAddress(0x28daecdc86eb8761)
print(f"ADRS={adrs.toHex()}")
binstr = format(idx, f'0{level}b')   # Binary representation of node
print(f"{binstr} idx={idx} level={level} height={height}")
node = leaf
print(f"{binstr} node={node}")

for sibling in authpath:
    # Last bit of the child index: 0 = left child, 1 = right child
    lastbit = idx & 1
    # Move up to the parent at the next level
    idx >>= 1
    height += 1
    level -= 1
    binstr = format(idx, f'0{level}b')
    print(f"{binstr} idx={idx} level={level} height={height}")
    # ADRS for the parent node: its height and its index at that height
    adrs.setTreeHeight(height)
    adrs.setTreeIndex(idx)
    print(f"ADRS={adrs.toHex()}")
    print(f"authpath={sibling}")
    # The sibling sits on the opposite side from the current node
    if lastbit:
        left, right = sibling, node
    else:
        left, right = node, sibling
    print(f"m1,m2={left} {right}")
    node = H(PKseed, adrs.toHex(), left, right)
    print(f"{binstr} node={node}")

print(f"Required root={root}")
# Explicit check instead of `assert`, which is skipped under `python -O`
# and would silently disable the verification this script exists for.
if node != root:
    raise AssertionError(
        f"computed root {node} does not match required root {root}")