# @file wots_authpath_verify.py (2023-03-16T14:29Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0
"""Verify the authpath in the first WOTS subtree."""
from spx_adrs import Adrs
from spx_sha256 import H
# sk = PRF(SK.seed, ADRS)
# node = F(PK.seed, ADRS, sk)
PKseed = 'B505D7CFAD1B497499323C8686325E47'
# Required root value in WOTS subtree
root = "f2ec3b2ae23a50355d057b97df65c8bc"
# Sibling node values consumed one per level, starting next to the leaf
authpath = [
    "b1f67e538bb9d4c2ef860f50085bcb72",
    "c10fb38ab696949b9417ddbefe8e4cad",
    "77a2617d410d8f1acd1fbc29830e1a51",
]
print("authpath=", authpath)
leaf = "56c1cd468c05d6b5a9ad57e87c4edf12"
print(f"leaf={leaf}")

# Walk from the leaf up to the root: at every level combine the current
# node with the next authpath entry and compare the final value with `root`.
# One authpath entry is consumed per level, so the number of levels is the
# length of the path.
nlevels = len(authpath)
level = nlevels
idx = 6  # index of the leaf within its level
height = 0
adrs = Adrs(Adrs.TREE, layer=0)
adrs.setTreeAddress(0x28daecdc86eb8761)
print(f"ADRS={adrs.toHex()}")
binstr = format(idx, f'0{level}b')  # Binary representation of node
print(f"{binstr} idx={idx} level={level} height={height}")
node = leaf
print(f"{binstr} node={node}")

for sibling in authpath:
    # Get last bit of child idx
    lastbit = idx & 1
    # Move up to next level
    idx >>= 1
    height += 1
    level -= 1
    binstr = format(idx, f'0{level}b')
    print(f"{binstr} idx={idx} level={level} height={height}")
    # The address passed to H describes the parent node being computed
    adrs.setTreeHeight(height)
    adrs.setTreeIndex(idx)
    print(f"ADRS={adrs.toHex()}")
    print(f"authpath={sibling}")
    # Last bit of child determines left/right order of authpath:
    # an odd child sits on the right, so its sibling goes first.
    if lastbit:
        print(f"m1,m2={sibling} {node}")
        node = H(PKseed, adrs.toHex(), sibling, node)
    else:
        print(f"m1,m2={node} {sibling}")
        node = H(PKseed, adrs.toHex(), node, sibling)
    print(f"{binstr} node={node}")

print(f"Required root={root}")
# Explicit check rather than an `assert` statement: asserts are removed when
# Python runs with -O, which would silently skip the verification.
if root != node:
    raise AssertionError(
        f"computed root {node} does not match required root {root}"
    )