# @file spx_makesig.py (2023-03-16T14:57Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0
"""Do it all to make a SPHINCS+ signature in one go."""
from spx_sha256 import *
from spx_adrs import Adrs
from spx_util import *
import re # For manipulation of sig string
"""
Reproduce SPHINCS+-SHA2-128f-simple test cases 0 and 1 from
NIST-PQ-Submission-SPHINCS-20201001/KAT/sphincs-sha256-128f-simple/PQCsignKAT_64.rsp
CAUTION: Many hard-coded constants! And hard-coded assert statements.
Lots of debugging statements, printed via DPRINT only when DEBUG=True (the default is DEBUG=False).
NOTE: the value of `optrand` can only be found by running the implementation code.
"""
# DEBUGGING
DEBUG = False
# Debug print: an alias for print() when DEBUG is set, otherwise a no-op.
DPRINT = print if DEBUG else lambda *a, **k: None
##################
# SELECT TEST CASE
##################
# Edit this...
test_case_to_do = 0
# TESTCASE0 enables the hard-coded intermediate-value asserts below, which
# are only known for test case 0.
TESTCASE0 = test_case_to_do == 0
if test_case_to_do == 0:
    # INPUT FOR TEST CASE 0
    test_title = 'TEST CASE 0'
    PKseed = 'B505D7CFAD1B497499323C8686325E47'
    PKroot = '4FDFA42840C84B1DDD0EA5CE46482020'
    SKseed = '7C9935A0B07694AA0C6D10E4DB6B1ADD'
    SK_prf = '2fd81a25ccb148032dcd739936737f2d'
    msg = \
        'D81C4D8D734FCBFBEADE3D3F8A039FAA2A2C9957E835AD55B22E75BF57BB556AC8'
    optrand = '33b3c07507e4201748494d832b6ee2a6'
    # Expected hash of signature
    expected_hash = 'ea2bef5299332943d7301a883aa6c1caba08975b7924ed581709b5b1c88beaad'
elif test_case_to_do == 1:
    # INPUT FOR TEST CASE 1
    test_title = 'TEST CASE 1'
    PKseed = 'D5A45A4CED06403C5557E87113CB30EA'
    PKroot = '8546AD883DDC43325A606C8B940C5EB1'
    SKseed = '4B622DE1350119C45A9F2E2EF3DC5DF5'
    SK_prf = '0A759D138CDFBD64C81CC7CC2F513345'
    msg = \
        '225D5CE2CEAC61930A07503FB59F7C2F936A3E075481DA3CA299A80F8C5DF9223A073E7B90E02EBF98CA2227EBA38C1AB2568209E46DBA961869C6F83983B17DCD49'
    optrand = '08e25538484cd7f1613248fe6c9f6b4e'
    # Expected hash of signature
    expected_hash = '5224c48f065b48afaa5ebb14bd76aecfc503cecd48a071bba5e6ea366775c2a9'
else:
    # SystemExit with a message prints it to stderr and exits with status 1.
    # Unlike exit(), it does not rely on the interactive-only `site` helper.
    raise SystemExit("**INVALID TEST CASE NUMBER!!")
# CONSTANTS (SPHINCS+-SHA2-128f-simple)
SPX_DGST_BYTES = 34        # H_msg output: 25 (mhash) + 8 (tree) + 1 (leaf) bytes
SPX_FORS_TREES = 33        # k: number of FORS trees
SPX_FORS_HEIGHT = 6        # a: height of each FORS tree
t = 1 << SPX_FORS_HEIGHT   # leaves per FORS tree (2**a = 64)
w = 16                     # Winternitz parameter
SPX_WOTS_LEN = 35          # WOTS+ chains: message digits + 3 checksum digits (see wots_chain)
# NOTE: despite its name, this is the number of hypertree LAYERS (d = 22)
# that the main loop iterates over, not the height of a single subtree.
SPX_TREE_HEIGHT = 22
print(f"About to compute SPHINCS+ signature for {test_title}...")
# Randomizer R = PRF_msg(SK.prf, OptRand, M), a 16-byte value as a hex string.
R = PRF_msg(SK_prf, optrand, msg)
DPRINT("R =", R)
if TESTCASE0:
    # Known-answer value for test case 0
    assert R == 'b77b5397031e67eb585dba86b10b710b'
# The signature is accumulated as annotated hex text: each component starts
# on a new line, and its first line carries a trailing " # label" comment.
# R is the first component.
sig = R + " # R\n"
# Compute H_msg from the message, public key and randomizer.
h = SHA256(R + PKseed + PKroot + msg)
h_msg = MGF1_SHA256(h, SPX_DGST_BYTES)
DPRINT("H_msg:", h_msg)
if TESTCASE0:
    assert h_msg == '5b7eb772aecf04c74af07d9d9c1c1f8d3a90dcda00d5bab1dc28daecdc86eb87611e'
# Split the 34-byte digest (68 hex chars) into mhash (25 bytes),
# tree address (8 bytes) and idx_leaf (1 byte)
DPRINT("Split up H_msg...")
mhash = h_msg[:50]
tree_hex = h_msg[50:66]
leaf_hex = h_msg[66:68]
DPRINT(f"mhash='{mhash}'")
DPRINT(f"tree='{tree_hex}', leaf='{leaf_hex}'")
# Decode tree address and leaf index into integers
tree_addr = int(tree_hex, 16) & 0x7fffffffffffffff  # keep low 63 bits
DPRINT(f"tree_addr = 0x{tree_addr:016x}")
idx_leaf = int(leaf_hex, 16) & 0x7  # keep low 3 bits
DPRINT(f"idx_leaf = {idx_leaf}")
if TESTCASE0:
    assert tree_addr == 0x28daecdc86eb8761
    assert idx_leaf == 6
# Interpret the 25-byte mhash as 33 * 6-bit unsigned integers. Bits are
# consumed least-significant first within each byte, and bit j of the
# stream chunk becomes bit j of indices[i].
m = bytes.fromhex(mhash)
DPRINT(f"mhash={m.hex()}")
indices = [0] * SPX_FORS_TREES
offset = 0
for i in range(SPX_FORS_TREES):
    for j in range(SPX_FORS_HEIGHT):
        indices[i] ^= ((m[offset >> 3] >> (offset & 0x7)) & 0x1) << j
        offset += 1
DPRINT("message_to_indices:\n", indices, sep='')
if TESTCASE0:
    assert indices[0] == 27 and indices[32] == 28
# FORS secret values revealed in the signature: for tree i this is the
# secret key of leaf indices[i], i.e. PRF(SK.seed, ADRS) with tree index
# i*t + indices[i]. They are written into `sig` later, next to each tree's
# authentication path.
adrs = Adrs(Adrs.FORS_TREE, layer=0)
adrs.setTreeAddress(tree_addr)
adrs.setKeyPairAddress(idx_leaf)
fors_sig_sk = []
for i, chosen_leaf in enumerate(indices):
    adrs.setTreeIndex(i * t + chosen_leaf)
    adrs_hex = adrs.toHex()
    DPRINT(f"ADRS={adrs_hex}")
    sk = PRF(SKseed, adrs_hex)
    DPRINT(f"fors_sig_sk[{i}]={sk}")
    fors_sig_sk.append(sk)
if TESTCASE0:
    # Known-answer values for test case 0 (first and last tree)
    assert fors_sig_sk[0] == '8c9f8091d1a1edbb6a8a041343c6e5c0'
    assert fors_sig_sk[32] == '446d9fc66808fcc5e0d47c0c381c7f9e'
# Compute the authpaths and root values for each of the k FORS trees,
# appending each tree's revealed secret value and auth path to `sig`.
roots = []
for i in range(SPX_FORS_TREES):
    # Leaves of FORS tree i: leaf j = F(PK.seed, ADRS, PRF(SK.seed, ADRS))
    # with tree index i*t + j.
    adrs = Adrs(Adrs.FORS_TREE, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    leaves = []
    for j in range(t):
        # Note this computes all the sk's including the one at indices[i]
        treeindex = i * t + j
        adrs.setTreeIndex(treeindex)
        adrs_hex = adrs.toHex()
        DPRINT(f"ADRS={adrs_hex}")
        sk = PRF(SKseed, adrs_hex)
        DPRINT(f"fors_sk[{i}][{j}]={sk}")
        pk = F(PKseed, adrs_hex, sk)
        DPRINT(f"fors_pk[{i}][{j}]={pk}")
        leaves.append(pk)
    DPRINT("leaves=", leaves, sep='\n')
    # Compute the root value for this FORS tree from a fresh FORS_TREE
    # address (tree index not set). i * t is passed as an index offset,
    # presumably so hash_root/authpath can derive node tree indices for
    # tree i (matches treeindex = i*t + j above) -- TODO confirm in helpers.
    adrs = Adrs(Adrs.FORS_TREE, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT(f"ADRS={adrs.toHex()}")
    root = hash_root(leaves, adrs, PKseed, i * t)
    DPRINT(f"root[{i}]={root}")
    roots.append(root)
    # and the authpath for indices[i]
    idx = indices[i]
    DPRINT(f"i={i} idx={idx}")
    auth = authpath(leaves, adrs, PKseed, idx, i * t)
    DPRINT(f"fors_auth_path[{i}]:")
    for a in auth:
        DPRINT(a)
    # Output the sig_sk and authpath to the signature value
    sig += fors_sig_sk[i] + f" # fors_sig_sk[{i}]\n"
    sig += auth[0] + f" # fors_auth_path[{i}]\n"
    sig += "\n".join(auth[1:]) + "\n"
# Signature so far: R followed by all FORS components
DPRINT(sig)
# Compute the FORS public key given the roots of the k FORS trees.
DPRINT("roots:")
for r in roots:
    DPRINT(r)
adrs = Adrs(Adrs.FORS_ROOTS, layer=0)
adrs.setTreeAddress(tree_addr)
adrs.setKeyPairAddress(idx_leaf)
DPRINT(f"ADRS={adrs.toHex()}")
fors_pk = T_len(PKseed, adrs.toHex(), "".join(roots))
DPRINT(f"fors_pk={fors_pk}")
def wots_chain(msghex, show_csum=False):
    """Return the WOTS+ chain lengths for a hex-encoded message.

    Despite the name, nothing is hashed here. The message is split into
    base-w digits and a 3-digit checksum is appended, most significant
    digit first. The signing loop applies F to each WOTS+ secret value
    that many times.

    Assumes w == 16, so each hex character is exactly one base-w digit and
    the checksum fits in three hex digits.

    Args:
        msghex: message as a hex string (either case).
        show_csum: if true, print the checksum via DPRINT.

    Returns:
        list of ints in range(16): one per hex character of msghex,
        followed by the 3 checksum digits.
    """
    # With w = 16 every hex character is already one base-w digit.
    digits = [int(ch, 16) for ch in msghex]
    # Checksum = sum of (w - 1 - digit), truncated to 12 bits (3 hex digits).
    csum = sum(w - 1 - d for d in digits) & 0xfff
    if show_csum:
        DPRINT(f"csum={csum:03x}")
    digits.extend((csum >> shift) & 0xF for shift in (8, 4, 0))
    return digits
# Input FORS public key to first WOTS signature
wots_input = fors_pk
# Each hypertree subtree has height 3: 2**3 = 8 WOTS+ leaves per subtree,
# and 3 bits of tree_addr select the leaf used at the next layer up.
HT_SUBTREE_HEIGHT = 3
# Loop for each of the 22 layers (subtrees) in the HT
for layer in range(SPX_TREE_HEIGHT):
    DPRINT(f"input to HT at layer {layer}={wots_input}")
    DPRINT(f"tree_addr={tree_addr:x} idx_leaf={idx_leaf}")
    # Chain lengths: base-16 digits of the input plus 3 checksum digits
    m = wots_chain(wots_input, False)
    DPRINT(m)
    DPRINT([hex(x) for x in m])
    DPRINT(f"len={len(m)}")

    # --- WOTS+ signature of wots_input by leaf idx_leaf of this subtree ---
    adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT(f"ADRS base={adrs.toHex()}")
    htsigs = []
    for idx in range(SPX_WOTS_LEN):
        DPRINT(f"Generate WOTS+ private key for i = {idx}")
        # sk = PRF(SK.seed, ADRS)
        adrs.setChainAddress(idx)
        adrs_c = adrs.toHex()
        DPRINT(f"ADRS={adrs_c}")
        sk = PRF(SKseed, adrs_c)
        DPRINT(f"sk={sk}")
        # Compute F^m_i(sk): apply F m[idx] times with hash address
        # 0 .. m[idx]-1, on a copy so the base ADRS is not modified.
        mi = m[idx]
        DPRINT(f"m[{idx}]={mi}")
        x = sk
        adrs_ht = Adrs.fromHex(adrs.toHex())
        for i in range(mi):
            adrs_ht.setHashAddress(i)
            adrs_c = adrs_ht.toHex()
            DPRINT(f"i={i} ADRS={adrs_c}")
            DPRINT(f"in={x}")
            x = F(PKseed, adrs_c, x)
            DPRINT(f"F(PK.seed, ADRS, in)={x}")
        DPRINT(f"ht_sig:{x}")
        htsigs.append(x)
    # Output this ht_sig (560 bytes) to the signature value
    sig += htsigs[0] + f" # ht_sig[{layer}]\n"
    sig += "\n".join(htsigs[1:]) + "\n"

    # --- Compute all leaves of the subtree at this layer ---
    # leaf = T_len(PK.seed, WOTS_PK ADRS, concatenated heads of all chains)
    leaves = []
    for this_leaf in range(1 << HT_SUBTREE_HEIGHT):
        DPRINT(f"this_leaf={this_leaf}")
        adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(this_leaf)
        DPRINT(adrs.toHex())
        heads = ""  # concatenation of heads of WOTS+ chains
        for chainaddr in range(SPX_WOTS_LEN):
            adrs.setChainAddress(chainaddr)
            adrs_hex = adrs.toHex()
            sk = PRF(SKseed, adrs_hex)
            DPRINT(f"sk[{chainaddr}]={sk}")
            # Debug-trace only the first two chains and the last one
            pk = chain(sk, 0, w - 1, PKseed, adrs_hex,
                       showdebug=(DEBUG and (chainaddr < 2
                                             or chainaddr == SPX_WOTS_LEN - 1)))
            DPRINT(f"pk={pk}")
            heads += pk
        DPRINT(f"Input to thash:\n{heads}")
        # for thash,
        wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=layer)
        wots_pk_adrs.setTreeAddress(tree_addr)
        wots_pk_adrs.setKeyPairAddress(this_leaf)
        wots_pk_addr_hex = wots_pk_adrs.toHex()
        DPRINT(f"wots_pk_addr={wots_pk_addr_hex}")
        leaf = T_len(PKseed, wots_pk_addr_hex, heads)
        DPRINT(f"leaf[{this_leaf}]={leaf}")
        leaves.append(leaf)
    DPRINT(leaves)

    # --- Root of this subtree's Merkle tree (using H) and the auth path ---
    adrs = Adrs(Adrs.TREE, layer=layer)
    adrs.setTreeAddress(tree_addr)
    DPRINT(f"ADRS={adrs.toHex()}")
    root = hash_root(leaves, adrs, PKseed)
    DPRINT(f"root={root}")
    if TESTCASE0 and layer == 0:
        assert root == 'f2ec3b2ae23a50355d057b97df65c8bc'
    # Compute the authentication path from leaf_idx
    DPRINT(f"Computing authpath for leaf index {idx_leaf}...")
    auth = authpath(leaves, adrs, PKseed, idx_leaf)
    DPRINT("authpath:")
    for a in auth:
        DPRINT(a)
    if TESTCASE0 and layer == 0:
        assert auth[2] == '77a2617d410d8f1acd1fbc29830e1a51'
    # Output this authpath to the signature value
    sig += auth[0] + f" # ht_auth_path[{layer}]\n"
    sig += "\n".join(auth[1:]) + "\n"
    # The next layer signs this root; the low bits of tree_addr pick the
    # signing leaf there, and the remaining bits address its subtree.
    wots_input = root
    idx_leaf = tree_addr & ((1 << HT_SUBTREE_HEIGHT) - 1)
    tree_addr >>= HT_SUBTREE_HEIGHT
    # Loop for next higher subtree...
# At the end the final root value MUST equal PK.root
print(f"Final root={wots_input}")
print(f"Expecting ={PKroot}")
assert wots_input.lower() == PKroot.lower()
# Print out and check the signature
print(f"sig:\n{sig}")
print("sig lines =", sig.count('\n'), "(expecting 1068)")
# Strip sig string down to pure hex digits: drop each trailing " # label"
# comment (MULTILINE makes $ match at every line end), then the line breaks.
sighex = re.sub(r'\s+\#.*?$', '', sig, flags=re.MULTILINE)
sighex = sighex.replace('\n', '')
print("sighex len =", len(sighex), "(expecting 17088 x 2 = 34176)")
# Compute SHA-256 of signature and check it against the known-answer hash
hash_sig = SHA256(sighex)
print(f"SHA256(sighex)={hash_sig}")
print(f"Expected ={expected_hash}")
assert hash_sig.lower() == expected_hash.lower(), "signature hash MISMATCH"
print("Signature hash OK")