# @file fors_authpath_verify.py (2023-03-16T14:29Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""Verify the first FORS authpath.""" 

from spx_sha256 import F, H
from spx_adrs import Adrs

# sk = PRF(SK.seed, ADRS)
# node = F(PK.seed, ADRS, sk)


# Public seed PK.seed (hex), used by the tweakable hash functions F and H
PKseed = 'B505D7CFAD1B497499323C8686325E47'
# Required root value of the FORS tree; the recomputed root must equal this
root = "8f5a82fe31bc814bdf198d01481651c5"

# Authentication path for the revealed leaf: one sibling node per tree level,
# ordered bottom-up (authpath[0] is combined with the leaf node itself).
authpath = [
    "90d9d26cf0068d14f2125ffa16dce594",
    "3af75452a07b7bc67344a77fba2bc51f",
    "1ee71cab80e5d588a4f3e181cfca703b",
    "e987519c0578e6edc2cac80c3d0f5781",
    "0d7dabe142fb0976638fe21d503dabde",
    "606c9bbca50ee8112dcc0cea78e81501",
]

print("authpath=", authpath)

# This FORS secret value is revealed in the signature (fors_sig_sk[0]), so the
# verifier knows it and can rebuild the tree root from it.
sk = "8c9f8091d1a1edbb6a8a041343c6e5c0"
print(f"sk={sk}")

# Base ADRS (hex). Its last two 4-byte words are placeholders that are
# overwritten via setTreeHeight/setTreeIndex for every node computed below.
adrs = Adrs.fromHex('0028daecdc86eb876103000000060000000000000000')

# Tree height: one authentication-path node is consumed per level, so derive
# it from the path instead of hard-coding it (keeps the two in sync).
level = len(authpath)
# Index of the revealed leaf within this tree
idx = 27
height = 0
if not 0 <= idx < 2 ** level:
    raise ValueError(f"leaf index {idx} out of range for tree height {level}")


def _set_node_adrs(address, tree_height, tree_index):
    """Point *address* at node (tree_height, tree_index); print and return its hex."""
    address.setTreeHeight(tree_height)
    address.setTreeIndex(tree_index)
    address_hex = address.toHex()
    print(f"ADRS={address_hex}")
    return address_hex


# Compute the leaf value directly above sk: leaf = F(PK.seed, ADRS, sk)
binstr = format(idx, f'0{level}b')  # node index in binary, padded to `level` bits
print(f"{binstr} level={level} height={height}")
adrs_c = _set_node_adrs(adrs, height, idx)
node = F(PKseed, adrs_c, sk)
print(f"{binstr} node={node}")

# Climb towards the root, hashing the current node with one authpath node
# per level.
for auth in authpath:
    # Lowest bit of the child index: 1 = right child, 0 = left child
    lastbit = idx & 1
    # Move up to the parent
    idx >>= 1
    height += 1
    level -= 1
    binstr = format(idx, f'0{level}b')
    print(f"{binstr} level={level} height={height}")
    adrs_c = _set_node_adrs(adrs, height, idx)
    print(f"authpath={auth}")
    if lastbit:
        # Current node is a right child, so its sibling goes first
        print(f"m1,m2={auth} {node}")
        node = H(PKseed, adrs_c, auth, node)
    else:
        print(f"m1,m2={node} {auth}")
        node = H(PKseed, adrs_c, node, auth)
    print(f"{binstr} node={node}")

print(f"Required root={root}")
# Explicit check instead of a bare `assert`, so verification is not skipped
# when Python runs with -O (which strips assert statements).
if node != root:
    raise AssertionError(f"root mismatch: computed {node}, required {root}")