# @file wots_PKgenRoot.py (2023-03-16T14:29Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0

"""Generate WOTS public key root value."""

from spx_sha256 import F, PRF

def setADRS(adrs_base, chain, index, iset):
    """Return adrs_base with the two trailing address words appended.

    The first appended word is ``chain``; the second is
    ``iset * 64 + index``.  Each word is rendered as lowercase hex,
    zero-padded to at least 8 digits.

    adrs_base -- hex string holding the leading part of the address
    chain     -- integer written as the first appended word
    index     -- integer offset within the set selected by ``iset``
    iset      -- integer set number; each set spans 64 index values
    """
    set_size = 64  # stride between consecutive sets
    last_word = iset * set_size + index
    return adrs_base + f'{chain:08x}{last_word:08x}'

# Winternitz parameter: every hash chain below applies F (w - 1) times.
w = 16
# Address prefix, i.e. the ADRS without its final two 4-byte words
# (setADRS() appends those).
# Type 0 = WOTS+ hash address, with keypair address = 0.
adrs_base = '1500000000000000000000000000'
# Public and secret seeds, as hex strings.
PKseed = 'B505D7CFAD1B497499323C8686325E47'
SKseed = '7c9935a0b07694aa0c6d10e4db6b1add'

# Step 1: derive the secret start value of chain 0,
#   sk = PRF(SK.seed, ADRS)
print("Generate WOTS+ private key for i = 0")
adrs_c = setADRS(adrs_base, 0, 0, 0)
print(f"ADRS={adrs_c}")
sk = PRF(SKseed, adrs_c)
print(f"sk={sk}")
# Author's reference value for sk: c04623124dfcdcb1de0ad8cfc68ebf73

# Step 2: walk chain 0 to its end, printing every intermediate value.
# This is "F^w(sk)", but NB F is applied only (w - 1) times.
node = sk
for step in range(w - 1):
    # The position within the chain goes into the last address word.
    adrs_c = setADRS(adrs_base, 0, step, 0)
    print(f"i={step} ADRS={adrs_c}")
    print(f"in={node}")
    node = F(PKseed, adrs_c, node)
    print(f"F(PK.seed, ADRS, in)={node}")

# End of chain 0, shown next to the expected value for visual comparison.
print(f"wots_pk:{node}")
print("OK:     f3a6275658f3be797af7022736613710")

# Compute the end value of every chain and collect them in chain order.
# 35 chains -- presumably len = len1 + len2 = 32 + 3 for n = 16, w = 16;
# TODO confirm against the parameter set in use.
chain_ends = []
for chain in range(35):
    print(f"chain={chain}")
    # Secret start value of this chain: sk = PRF(SK.seed, ADRS),
    # using the address with the last word set to 0.
    adrs_c = setADRS(adrs_base, chain, 0, 0)
    print(f"ADRS={adrs_c}")
    sk = PRF(SKseed, adrs_c)
    print(f"sk={sk}")

    # "F^w(sk)": NB F is applied only (w - 1) times, with the step number
    # placed in the last address word.
    node = sk
    for step in range(w - 1):
        node = F(PKseed, setADRS(adrs_base, chain, step, 0), node)
    print(f"wots_pk:{node}")
    chain_ends.append(node)

# Concatenation of all chain end values: the input to the T_len step.
wots_tmp = "".join(chain_ends)
print("wots_tmp before T_len:\n" + wots_tmp)

# Address for the compression step, split up for readability:
# 15 0000000000000000 01 000000000000000000000000
adrs_c = "15000000000000000001000000000000000000000000"
# This script uses the same F to compress the concatenated chain ends.
leaf = F(PKseed, adrs_c, wots_tmp)
print(f"leaf={leaf}")
# Check against the known-good leaf value.
assert leaf == '505df0061b7e0041c8501bc5030ad439'