# @file spx_makesigs100.py (2023-03-18T10:13Z)
# @author David Ireland <www.di-mgt.com.au/contact>
# @copyright 2023 DI Management Services Pty Ltd
# @license Apache-2.0
"""Reproduce all 100 SPHINCS+-SHA2-128f-simple test cases."""
from spx_sha256 import *
from spx_adrs import Adrs
from spx_util import *
import re # For manipulation of sig string
import json
"""
Reproduce all 100 SPHINCS+-SHA2-128f-simple test cases 0..99 from
NIST-PQ-Submission-SPHINCS-20201001/KAT/sphincs-sha256-128f-simple/PQCsignKAT_64.rsp
CAUTION: Many hard-coded constants!
Lots of debugging statements, disabled by default (DEBUG=False); set DEBUG=True to turn them on.
NOTE: the value of `optrand` can only be found by running the implementation code.
"""
# DEBUGGING: flip DEBUG to True to get the step-by-step trace output.
DEBUG = False
if DEBUG:
    DPRINT = print
else:
    def DPRINT(*args, **kwargs):
        """Discard debug output; accepts the same arguments as print()."""
        return None
def make_sig(count, PKseed, PKroot, SKseed, SK_prf, msg, optrand, expected_hash):
    """Recompute one SPHINCS+-SHA2-128f-simple KAT signature and check it.

    All values are handled as hex-encoded strings, as consumed and produced by
    the helpers imported from spx_sha256 / spx_util (PRF_msg, SHA256,
    MGF1_SHA256, PRF, F, T_len, chain, hash_root, authpath).

    Args:
        count: Test-case number, used only in progress output.
        PKseed: Public-key seed (hex).
        PKroot: Public-key root (hex); the final hypertree root must equal it.
        SKseed: Secret-key seed (hex).
        SK_prf: Secret PRF key used to derive the randomizer R (hex).
        msg: Message to sign (hex).
        optrand: The `optrand` value used by the implementation for this
            test case (hex).
        expected_hash: Expected SHA-256 of the signature's hex digits.

    Returns:
        The signature as one string of hex digits (comments and line breaks
        removed).

    Raises:
        AssertionError: if the final hypertree root differs from PKroot, or the
            SHA-256 of the signature differs from expected_hash. Raised
            explicitly (not via `assert`) so the checks survive `python -O`.
    """
    # CONSTANTS (hard-coded for SPHINCS+-SHA2-128f-simple)
    SPX_DGST_BYTES = 34
    SPX_FORS_TREES = 33
    SPX_FORS_HEIGHT = 6
    SPX_WOTS_W = 16
    w = SPX_WOTS_W
    t = (1 << SPX_FORS_HEIGHT)  # 2^6 = 64 leaves per FORS tree
    SPX_WOTS_LEN = 35
    SPX_D = 22
    SPX_HT_HEIGHT = 3
    SPX_HT_LEAVES = (1 << SPX_HT_HEIGHT)  # 2^3 = 8 leaves per subtree
    print(f"\nAbout to compute SPHINCS+ signature for TEST CASE {count}...")
    # Compose the signature as a hex-encoded string with line breaks and
    # comments; both are stripped again before hashing at the end.
    sig = ""

    # Compute the randomizer R of 16 bytes using hex strings
    R = PRF_msg(SK_prf, optrand, msg)
    DPRINT("R =", R)
    sig += R + " # R\n"

    # Compute H_msg from the message, public key and randomizer.
    h = SHA256(R + PKseed + PKroot + msg)
    h_msg = MGF1_SHA256(h, SPX_DGST_BYTES)  # 34
    DPRINT("H_msg:", h_msg)
    # Split into mhash (25 bytes), tree address (8 bytes) and idx_leaf (1 byte);
    # offsets below are in hex digits (2 per byte).
    DPRINT("Split up H_msg...")
    mhash = h_msg[:50]
    tree_hex = h_msg[50:66]
    leaf_hex = h_msg[66:68]
    DPRINT(f"mhash='{mhash}'")
    DPRINT(f"tree='{tree_hex}', leaf='{leaf_hex}'")
    # Decode tree address and leaf index into integers
    tree_addr = int(tree_hex, 16) & 0x7fffffffffffffff  # 63 bits
    DPRINT(f"tree_addr = 0x{tree_addr:08x}")
    idx_leaf = int(leaf_hex, 16) & 0x7  # 3 bits
    DPRINT(f"idx_leaf = {idx_leaf}")

    # Interpret 25-byte mhash as 33 * 6-bit unsigned integers; bits are read
    # least-significant-bit first within each byte.
    m = bytes.fromhex(mhash)
    DPRINT(f"mhash={m.hex()}")
    indices = []
    offset = 0
    for i in range(SPX_FORS_TREES):
        index = 0
        for j in range(SPX_FORS_HEIGHT):
            index ^= ((m[offset >> 3] >> (offset & 0x7)) & 0x1) << j
            offset += 1
        indices.append(index)
    DPRINT("message_to_indices:\n", indices, sep='')

    # Compute all 33 FORS signature sk values using the indices
    # (we'll output these later, interleaved with the auth paths)
    fors_sig_sk = []
    adrs = Adrs(Adrs.FORS_TREE, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    for i in range(SPX_FORS_TREES):
        adrs.setTreeIndex(i * t + indices[i])
        DPRINT(f"ADRS={adrs.toHex()}")
        sk = PRF(SKseed, adrs.toHex())
        DPRINT(f"fors_sig_sk[{i}]={sk}")
        fors_sig_sk.append(sk)

    # Compute the authpaths and root values for each of the k FORS trees
    roots = []
    for i in range(SPX_FORS_TREES):  # i = 0..32
        # Compute the FORS sk and pk value of every leaf of tree i
        adrs = Adrs(Adrs.FORS_TREE, layer=0)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        leaves = []
        for j in range(t):
            # Note this computes all the sk's including the one at indices[i]
            adrs.setTreeIndex(i * t + j)
            DPRINT(f"ADRS={adrs.toHex()}")
            sk = PRF(SKseed, adrs.toHex())
            DPRINT(f"fors_sk[{i}][{j}]={sk}")
            pk = F(PKseed, adrs.toHex(), sk)
            DPRINT(f"fors_pk[{i}][{j}]={pk}")
            leaves.append(pk)
        DPRINT("leaves=", leaves, sep='\n')
        # Compute the root value for this FORS tree
        adrs = Adrs(Adrs.FORS_TREE, layer=0)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        DPRINT(f"ADRS={adrs.toHex()}")
        root = hash_root(leaves, adrs, PKseed, i * t)
        DPRINT(f"root[{i}]={root}")
        roots.append(root)
        # and the authpath for indices[i] (same adrs object as for the root)
        idx = indices[i]
        DPRINT(f"i={i} idx={idx}")
        auth = authpath(leaves, adrs, PKseed, idx, i * t)
        DPRINT(f"fors_auth_path[{i}]:")
        for a in auth:
            DPRINT(a)
        # Output the sig_sk and authpath to the signature value
        sig += fors_sig_sk[i] + f" # fors_sig_sk[{i}]\n"
        sig += auth[0] + f" # fors_auth_path[{i}]\n"
        sig += "\n".join(auth[1:]) + "\n"
    DPRINT(sig)

    # Compute the FORS public key given the roots of the k FORS trees.
    DPRINT("roots:")
    for r in roots:
        DPRINT(r)
    adrs = Adrs(Adrs.FORS_ROOTS, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT(f"ADRS={adrs.toHex()}")
    fors_pk = T_len(PKseed, adrs.toHex(), "".join(roots))
    DPRINT(f"fors_pk={fors_pk}")

    def wots_chain(msghex, show_csum=False):
        """Return the base-w digits of msghex followed by 3 checksum digits.

        With w == 16 each hex digit is one base-w digit, so a 16-byte
        (32-hex-digit) input yields 32 + 3 = 35 = SPX_WOTS_LEN digits.
        """
        # The hex-digit shortcut below is only valid for w = 16.
        assert w == 16
        mymsg = [int(x, 16) for x in msghex]
        # WOTS+ checksum: sum of (w - 1 - digit) over the message digits
        csum = sum(w - 1 - digit for digit in mymsg)
        csum &= 0xfff  # truncate to 12 bits (3 base-16 digits)
        if show_csum:
            DPRINT(f"csum={csum:03x}")
        # Append the checksum as 3 digits, most significant first
        mymsg.append((csum >> 8) & 0xF)
        mymsg.append((csum >> 4) & 0xF)
        mymsg.append((csum >> 0) & 0xF)
        return mymsg

    # Input FORS public key to first WOTS signature
    wots_input = fors_pk
    # Loop for each of 22 subtrees in the HT
    for layer in range(SPX_D):
        DPRINT(f"input to HT at layer {layer}={wots_input}")
        DPRINT(f"tree_addr={tree_addr:x} idx_leaf={idx_leaf}")
        m = wots_chain(wots_input, False)
        DPRINT(m)
        DPRINT([hex(x) for x in m])
        DPRINT(f"len={len(m)}")

        # Compute the next WOTS signature.
        adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        DPRINT(f"ADRS base={adrs.toHex()}")
        htsigs = []
        for idx in range(SPX_WOTS_LEN):  # 35
            DPRINT(f"Generate WOTS+ private key for i = {idx}")
            # sk = PRF(SK.seed, ADRS)
            adrs.setChainAddress(idx)
            adrs_c = adrs.toHex()
            DPRINT(f"ADRS={adrs_c}")
            sk = PRF(SKseed, adrs_c)
            DPRINT(f"sk={sk}")
            # Compute F^m_i(sk): apply F m[idx] times with hash address 0, 1, ...
            mi = m[idx]
            DPRINT(f"m[{idx}]={mi}")
            x = sk
            # Separate ADRS for the chain steps, rebuilt from adrs's hex encoding
            adrs_ht = Adrs.fromHex(adrs.toHex())
            for i in range(mi):
                adrs_ht.setHashAddress(i)
                adrs_c = adrs_ht.toHex()
                DPRINT(f"i={i} ADRS={adrs_c}")
                DPRINT(f"in={x}")
                x = F(PKseed, adrs_c, x)
                DPRINT(f"F(PK.seed, ADRS, in)={x}")
            DPRINT(f"ht_sig:{x}")
            htsigs.append(x)
        # Output this ht_sig (560 bytes) to the signature value
        sig += htsigs[0] + f" # ht_sig[{layer}]\n"
        sig += "\n".join(htsigs[1:]) + "\n"

        # Compute all leaves of subtree at this layer
        leaves = []
        for this_leaf in range(SPX_HT_LEAVES):
            DPRINT(f"this_leaf={this_leaf}")
            adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
            adrs.setTreeAddress(tree_addr)
            adrs.setKeyPairAddress(this_leaf)
            DPRINT(adrs.toHex())
            heads = ""  # concatenation of heads of WOTS+ chains
            for chainaddr in range(SPX_WOTS_LEN):
                adrs.setChainAddress(chainaddr)
                adrs_hex = adrs.toHex()
                sk = PRF(SKseed, adrs_hex)
                DPRINT(f"sk[{chainaddr}]={sk}")
                # Only trace the first two and the last chain when debugging
                pk = chain(sk, 0, w - 1, PKseed, adrs_hex,
                           showdebug=(DEBUG and (chainaddr < 2 or chainaddr == (SPX_WOTS_LEN - 1))))
                DPRINT(f"pk={pk}")
                heads += pk
            DPRINT(f"Input to thash:\n{heads}")
            # Compress the concatenated chain heads into this leaf with thash
            wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=layer)
            wots_pk_adrs.setTreeAddress(tree_addr)
            wots_pk_adrs.setKeyPairAddress(this_leaf)
            wots_pk_addr_hex = wots_pk_adrs.toHex()
            DPRINT(f"wots_pk_addr={wots_pk_addr_hex}")
            leaf = T_len(PKseed, wots_pk_addr_hex, heads)
            DPRINT(f"leaf[{this_leaf}]={leaf}")
            leaves.append(leaf)
        DPRINT(leaves)

        # Compute the root node of Merkle tree using H
        # Start with 8 leaf values in array
        adrs = Adrs(Adrs.TREE, layer=layer)
        adrs.setTreeAddress(tree_addr)
        DPRINT(f"ADRS={adrs.toHex()}")
        root = hash_root(leaves, adrs, PKseed)
        DPRINT(f"root={root}")
        # Compute the authentication path from leaf_idx
        DPRINT(f"Computing authpath for leaf index {idx_leaf}...")
        auth = authpath(leaves, adrs, PKseed, idx_leaf)
        DPRINT("authpath:")
        for a in auth:
            DPRINT(a)
        # Output this authpath to the signature value
        sig += auth[0] + f" # ht_auth_path[{layer}]\n"
        sig += "\n".join(auth[1:]) + "\n"
        # Set next wots_input to root and change tree_addr for next layer:
        # the low 3 bits select the leaf in the parent subtree.
        wots_input = root
        idx_leaf = tree_addr & 0x7  # (2^3 - 1)
        tree_addr >>= 3
        # Loop for next higher subtree...

    # At the end the final root value MUST equal PK.root
    print(f"Final root={wots_input}")
    print(f"Expecting ={PKroot}")
    if wots_input.lower() != PKroot.lower():
        raise AssertionError(f"test case {count}: final root does not match PK.root")

    # Print out and check the signature
    print(f"sig:\n{sig}")
    print("sig lines =", sig.count('\n'), "(expecting 1068)")
    # Strip sig string down to pure hex digits
    sighex = re.sub(r'\s+#.*?$', '', sig, flags=re.MULTILINE)
    sighex = sighex.replace('\n', '')
    print("sighex len =", len(sighex), "(expecting 17088 x 2 = 34176)")
    # Compute SHA-256 of signature
    hash_sig = SHA256(sighex)
    print(f"SHA256(sighex)={hash_sig}")
    # Check hash matches what we expect (hex case-insensitive, like PK.root above)
    print(f"Expected ={expected_hash}")
    if hash_sig.lower() != expected_hash.lower():
        raise AssertionError(f"test case {count}: signature hash does not match expected hash")
    return sighex
if __name__ == '__main__':
    import sys
    # Read in json file with input data for 100 test cases. An alternative
    # KAT file may be given as the first command-line argument.
    kat_path = sys.argv[1] if len(sys.argv) > 1 else 'spx_inputkat100.json'
    # JSON is UTF-8; don't depend on the platform's default text encoding.
    with open(kat_path, 'r', encoding='utf-8') as f:
        katall = json.load(f)
    # Do each test case
    for k in katall:
        # def make_sig(count, PKseed, PKroot, SKseed, SK_prf, msg, optrand, expected_hash):
        make_sig(k['count'], k['PKseed'], k['PKroot'], k['SKseed'],
                 k['SK_prf'], k['msg'], k['optrand'], k['hash'])
    print("ALL DONE.")