# @file slhdsa_makesig.py
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""Make an SLH-DSA signature with any of the 12 parameter sets defined in FIPS.205 final (2024-08-13).

This is meant as a demonstration of how SLH-DSA works, not an efficient or secure implementation.
It uses hex-encoded representations of all variables throughout, so it is noticeably slower than a byte-oriented implementation.
Ref: NIST FIPS.205 "Stateless Hash-Based Digital Signature Standard" SLH-DSA (2024-08-13)
https://doi.org/10.6028/NIST.FIPS.205
"""

import slh_hashfuncs as hashfuncs
from slh_adrs import Adrs
from slh_tree import hash_root, authpath, chain
from slh_params import params
import re  # For manipulation of sig string
import os

# DEBUGGING switches: set DEBUG True for basic tracing, DEBUGA True for more detailed tracing.
DEBUG = False
DEBUGA = False  # More detailed debug
# When a switch is off its print function is a no-op that accepts and ignores any arguments.
DPRINT = print if DEBUG else (lambda *args, **kwargs: None)
DPRINTA = print if DEBUGA else (lambda *args, **kwargs: None)


# [FIPS.205] approves 12 parameter sets for use with SLH-DSA. Ref: Section 11, SLH-DSA parameter sets table.
# Each set fixes n, h, d, a, k, lg_w and m; the subtree height h' = h/d is derived where needed.
# compr=True marks the SHA2-based sets, compr=False the SHAKE-based sets.

# Hash-function bundles shared by several parameter sets (temporary; deleted below).
# The 128-bit SHA2 sets use the SHA-256 variants of every function.
_SHA2_128_FUNCS = dict(H_msg=hashfuncs.H_msg_sha256, PRF=hashfuncs.PRF_sha256, PRF_msg=hashfuncs.PRF_msg_sha256,
                       F=hashfuncs.F_sha256, H=hashfuncs.H_sha256, T_len=hashfuncs.T_len_sha256)
# The 192/256-bit SHA2 sets keep SHA-256 for PRF and F but use SHA-512 for H_msg, PRF_msg, H and T_len.
_SHA2_192_256_FUNCS = dict(H_msg=hashfuncs.H_msg_sha512, PRF=hashfuncs.PRF_sha256, PRF_msg=hashfuncs.PRF_msg_sha512,
                           F=hashfuncs.F_sha256, H=hashfuncs.H_sha512, T_len=hashfuncs.T_len_sha512)
# Every SHAKE set uses the same SHAKE-based functions.
_SHAKE_FUNCS = dict(H_msg=hashfuncs.H_msg_shake, PRF=hashfuncs.PRF_shake, PRF_msg=hashfuncs.PRF_msg_shake,
                    F=hashfuncs.F_shake, H=hashfuncs.H_shake, T_len=hashfuncs.T_len_shake)

params_sha2_128s = params(name='SLH-DSA-SHA2-128s', n=16, h=63, d=7, a=12, k=14, lgw=4, m=30, compr=True, **_SHA2_128_FUNCS)
params_sha2_128f = params(name='SLH-DSA-SHA2-128f', n=16, h=66, d=22, a=6, k=33, lgw=4, m=34, compr=True, **_SHA2_128_FUNCS)
params_sha2_192s = params(name='SLH-DSA-SHA2-192s', n=24, h=63, d=7, a=14, k=17, lgw=4, m=39, compr=True, **_SHA2_192_256_FUNCS)
params_sha2_192f = params(name='SLH-DSA-SHA2-192f', n=24, h=66, d=22, a=8, k=33, lgw=4, m=42, compr=True, **_SHA2_192_256_FUNCS)
params_sha2_256s = params(name='SLH-DSA-SHA2-256s', n=32, h=64, d=8, a=14, k=22, lgw=4, m=47, compr=True, **_SHA2_192_256_FUNCS)
params_sha2_256f = params(name='SLH-DSA-SHA2-256f', n=32, h=68, d=17, a=9, k=35, lgw=4, m=49, compr=True, **_SHA2_192_256_FUNCS)
params_shake_128s = params(name='SLH-DSA-SHAKE-128s', n=16, h=63, d=7, a=12, k=14, lgw=4, m=30, compr=False, **_SHAKE_FUNCS)
params_shake_128f = params(name='SLH-DSA-SHAKE-128f', n=16, h=66, d=22, a=6, k=33, lgw=4, m=34, compr=False, **_SHAKE_FUNCS)
params_shake_192s = params(name='SLH-DSA-SHAKE-192s', n=24, h=63, d=7, a=14, k=17, lgw=4, m=39, compr=False, **_SHAKE_FUNCS)
params_shake_192f = params(name='SLH-DSA-SHAKE-192f', n=24, h=66, d=22, a=8, k=33, lgw=4, m=42, compr=False, **_SHAKE_FUNCS)
params_shake_256s = params(name='SLH-DSA-SHAKE-256s', n=32, h=64, d=8, a=14, k=22, lgw=4, m=47, compr=False, **_SHAKE_FUNCS)
params_shake_256f = params(name='SLH-DSA-SHAKE-256f', n=32, h=68, d=17, a=9, k=35, lgw=4, m=49, compr=False, **_SHAKE_FUNCS)

# The bundles were only construction helpers; keep the module namespace unchanged.
del _SHA2_128_FUNCS, _SHA2_192_256_FUNCS, _SHAKE_FUNCS

# All 12 parameter sets keyed on their name string (NOTE lowercase 'f' and 's'),
# e.g. "SLH-DSA-SHAKE-192f". Insertion order: SHA2 then SHAKE, by size, 'f' before 's'.
params_dict = {
    ps.name: ps
    for ps in (
        params_sha2_128f, params_sha2_128s,
        params_sha2_192f, params_sha2_192s,
        params_sha2_256f, params_sha2_256s,
        params_shake_128f, params_shake_128s,
        params_shake_192f, params_shake_192s,
        params_shake_256f, params_shake_256s,
    )
}


def show_params(param_name=""):
    """Print the numeric parameters and hash-function names of one or all parameter sets.

    INPUT: param_name  Exact key of `params_dict` (e.g. "SLH-DSA-SHA2-128s"), or "" to show all.
    Raises KeyError if `param_name` is non-empty and not a known parameter-set name.
    """
    # Look up first so an unknown name fails before anything is printed.
    selected = [params_dict[param_name]] if param_name else list(params_dict.values())
    print("                   n  h d h' a k lgw m")
    for ps in selected:  # `ps` avoids shadowing the imported `params` type
        print(ps.name, ps.n, ps.h, ps.d, ps.h//ps.d, ps.a, ps.k, ps.lgw, ps.m)
        funcs = (ps.H_msg, ps.PRF, ps.PRF_msg, ps.F, ps.H, ps.T_len)
        print(*(fn.__name__ for fn in funcs))



def wots_chain(msghex, show_csum=False):
    """Convert a hex-encoded message to base-16 digits and append the WOTS+ checksum.

    This is the message-encoding step used by WOTS+ signing, specialised to
    lg_w = 4 (w = 16), where len_2 = 3 base-16 digits for every approved n.
    INPUT: msghex     Message (n bytes) encoded in hex; each hex digit is one base-16 digit.
           show_csum  If True, debug-print the checksum.
    OUTPUT: List of len_1 + 3 integers in the range 0..15, where len_1 = len(msghex) = 2n.
    """
    w = 16  # always => lgw = 4 always
    # Each hex digit is already one 4-bit digit, so we can split the hex string directly.
    digits = [int(ch, 16) for ch in msghex]
    # csum = sum(w - 1 - msg[i]); for n <= 32 this is at most 64 * 15 = 960 < 2^12,
    # so the 12-bit truncation never discards anything for valid inputs.
    csum = sum(w - 1 - digit for digit in digits) & 0xfff
    if show_csum:
        DPRINT(f"csum=0x{csum:03x}")
    # Append csum as 3 big-endian nibbles (equivalent to shifting csum left by 4 and
    # taking the first len_2 = 3 base-16 digits of its 2-byte encoding).
    digits.extend((csum >> shift) & 0xF for shift in (8, 4, 0))
    return digits


def base_2_b(X, b, out_len):
    """Compute the base 2^b representation of X.
    INPUT: Byte string X, integer b, output length out_len.
    OUTPUT: Array of out_len integers in the range 0 <= i < 2^b.
    FIPS.205 Algorithm 4: base_2^b(X, b, out_len)
    X must contain at least ceil(out_len * b / 8) bytes, otherwise IndexError is raised.
    """
    baseb = [0] * out_len
    i = 0       # index of the next byte of X to consume
    bits = 0    # number of unconsumed bits held in the low end of `total`
    total = 0
    mask = (1 << b) - 1  # NOTE: a mod 2^b === a & (2^b - 1)

    for out in range(out_len):
        while bits < b:
            total = (total << 8) + X[i]
            i += 1
            bits += 8
        bits -= b
        baseb[out] = (total >> bits) & mask
        # Keep exactly the `bits` unconsumed low bits. Masking with 2^b - 1 instead
        # would drop pending bits whenever more than b bits remain (e.g. b = 1, 2, 3, 5).
        total &= (1 << bits) - 1
    return baseb


def slhdsa_sign_internal(sk: str, msg: str, addrnd: str, params, Rok='', Hmsgok='', fors_sig_sk_0='', fors_pk_ok='', sig_hash_ok=''):
    """Generic SLH-DSA sign_internal function. All parameters hex encoded.

    If `addrnd` is empty, use deterministic variant (opt_rand = PK.seed).
    We do not generate random bits here.

    INPUT:
        sk      SK.seed || SK.prf || PK.seed || PK.root, each n bytes, hex encoded.
        msg     Message to sign, hex encoded (slhdsa_sign passes the formatted M' here).
        addrnd  Additional randomness of exactly n bytes (hex), or '' for deterministic signing.
        params  Parameter set object (e.g. one of the entries of `params_dict`).
        Rok, Hmsgok, fors_sig_sk_0, fors_pk_ok, sig_hash_ok
                Optional expected values for known-answer testing. When non-empty,
                R, H_msg, the first FORS signature secret, the FORS public key and the
                hash of the signature are respectively asserted equal to them.
    OUTPUT: Tuple (sighex, sig): the signature as one long line of hex characters,
        and the same signature annotated with line breaks and `# ...` comments.
    RAISES: ValueError if `addrnd` is non-empty but not exactly n bytes long.
        AssertionError if a supplied expected value does not match.
    NOTE: If the recomputed hypertree root differs from PK.root an error is printed,
        but the (invalid) signature is still returned.
    """

    DPRINT("\n" + params.name, "n =", params.n)

    n = params.n
    # Parse sk = SK.seed || SK.prf || PK.seed || PK.root (n bytes = 2n hex chars each)
    SKseed = sk[:n*2]
    SKprf = sk[n*2:n*4]
    PKseed = sk[n*4:n*6]
    PKroot = sk[n*6:]
    DPRINT("sk=", sk)
    DPRINT(SKseed, SKprf, PKseed, PKroot)

    # Start composing the signature as a hex-encoded string with line breaks and comments
    sig = ""

    # If hedged, use additional random as opt_rand; else use PK.seed if deterministic
    if not addrnd:
        optrand = PKseed
    elif len(addrnd) != 2 * n:
        # ValueError subclasses Exception, so existing `except Exception` handlers still work
        raise ValueError("Invalid addrnd parameter: must be exactly n bytes")
    else:
        optrand = addrnd

    # Compute the randomizer R of n bytes using hex strings
    R = params.PRF_msg(SKprf, optrand, msg, params)
    DPRINT("R =", R)
    DPRINT("R'=", Rok)
    if Rok:
        assert Rok == R
    sig += R + " # R\n"

    # Compute H_msg from the message, public key and randomizer.
    # H_msg = mhash || tree || idx_leaf
    h_msg = params.H_msg(R, PKseed, PKroot, msg, params)
    DPRINT("H_msg :", h_msg)
    DPRINT("H_msg':", Hmsgok)
    if Hmsgok:
        assert Hmsgok == h_msg

    # Split into mhash (md), tree address and idx_leaf
    md_bits = params.k * params.a
    idx_tree_bits = params.h - params.h // params.d
    idx_leaf_bits = params.h // params.d
    md_len = (md_bits + 7) // 8              # ceil(k*a / 8) bytes
    idx_tree_len = (idx_tree_bits + 7) // 8  # ceil((h - h/d) / 8) bytes
    idx_leaf_len = (idx_leaf_bits + 7) // 8  # ceil((h/d) / 8) bytes
    assert md_len + idx_tree_len + idx_leaf_len == params.m
    DPRINT("Split up H_msg...")
    # compute message digest and index (hex offsets are byte offsets * 2)
    mhash = h_msg[:md_len * 2]
    tree_hex = h_msg[md_len * 2:(md_len + idx_tree_len) * 2]
    leaf_hex = h_msg[(md_len + idx_tree_len) * 2:(md_len + idx_tree_len + idx_leaf_len) * 2]
    DPRINT(f"mhash='{mhash}'")
    DPRINT(f"tree='{tree_hex}', leaf='{leaf_hex}'")
    # Decode tree address and leaf index into integers, keeping only the low-order bits
    tree_addr = int(tree_hex, 16) & ((1 << idx_tree_bits) - 1)
    DPRINT(f"tree_addr = 0x{tree_addr:08x}")
    idx_leaf = int(leaf_hex, 16) & ((1 << idx_leaf_bits) - 1)
    DPRINT(f"idx_leaf = {idx_leaf}")

    # Interpret mhash as k x a-bit integers using the base_2^b algorithm.
    indices = base_2_b(bytes.fromhex(mhash), params.a, params.k)
    DPRINT("message_to_indices:\n", list(indices), sep='')

    # Compute all k FORS signature sk values using the indices
    # (we'll output these later)
    fors_sig_sk = []
    adrs = Adrs(Adrs.FORS_PRF, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT("base adrs =", adrs.toHex(params.compr))

    t = 1 << params.a  # t = 2^a leaves in each FORS tree
    for i in range(params.k):  # SPX_FORS_TREES
        adrs.setTreeIndex(i * t + indices[i])
        DPRINTA(f"ADRS={adrs.toHexSP(params.compr)}")
        fors_sk = params.PRF(PKseed, SKseed, adrs.toHex(params.compr), params)
        DPRINTA(f"fors_sig_sk[{i}]={fors_sk}")
        fors_sig_sk.append(fors_sk)

    # Check we have the first FORS sig sk correct
    if fors_sig_sk_0:
        assert fors_sig_sk_0 == fors_sig_sk[0]

    # Compute the authpaths and root values for each of the k FORS trees
    roots = []
    for i in range(params.k):  # SPX_FORS_TREES
        adrs = Adrs(Adrs.FORS_TREE, layer=0)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        leaves = []
        DPRINT(f"About to compute {t} sk's for tree {i}")
        for j in range(t):
            # Note this computes all the sk's and pk's including the one at indices[i]
            adrs.setTreeIndex(i * t + j)
            # fors_sk_gen...
            skAdrs = adrs.copy()
            skAdrs.setType(Adrs.FORS_PRF)
            skAdrs.setTreeIndex(adrs.getTreeIndex())
            skAdrs.setKeyPairAddress(adrs.getKeyPairAddress())
            DPRINTA(f"skADRS={skAdrs.toHexSP(params.compr)}")
            # [v3.1] use FORS_PRF for sk but FORS_TREE for pk
            fors_sk = params.PRF(PKseed, SKseed, skAdrs.toHex(params.compr), params)
            DPRINTA(f"fors_sk[{i}][{j}]={fors_sk}")
            DPRINTA(f"pkADRS={adrs.toHexSP(params.compr)}")
            pk = params.F(PKseed, adrs.toHex(params.compr), fors_sk, params)
            DPRINTA(f"fors_pk[{i}][{j}]={pk}")
            leaves.append(pk)

        DPRINT("# leaves=", len(leaves))
        DPRINT(leaves)
        # Compute the root value for this FORS tree
        adrs = Adrs(Adrs.FORS_TREE, layer=0)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        DPRINT(f"ADRS={adrs.toHex(params.compr)}")
        DPRINT("about to call hash_root...")
        root = hash_root(leaves, adrs, PKseed, params, i * t)
        DPRINT(f"root[{i}]={root}")
        roots.append(root)

        # and the authpath for indices[i]
        idx = indices[i]
        DPRINT(f"i={i} idx={idx}")
        auth = authpath(leaves, adrs, PKseed, idx, params, i * t)
        DPRINT(f"fors_auth_path[{i}]:...")
        for node in auth:
            DPRINT(node)
        # Output the sig_sk and authpath to the signature value
        sig += fors_sig_sk[i] + f" # fors_sig_sk[{i}]\n"
        sig += auth[0] + f" # fors_auth_path[{i}]\n"
        sig += "\n".join(auth[1:]) + "\n"

    DPRINTA("sig (so far):\n" + sig)

    # Compute the FORS public key given the roots of the k FORS trees.
    DPRINT("roots:")
    for r in roots:
        DPRINT(r)
    adrs = Adrs(Adrs.FORS_ROOTS, layer=0)
    adrs.setTreeAddress(tree_addr)
    adrs.setKeyPairAddress(idx_leaf)
    DPRINT(f"ADRS={adrs.toHex(params.compr)}")
    fors_pk = params.T_len(PKseed, adrs.toHex(params.compr), "".join(roots), params)
    DPRINT(f"fors_pk={fors_pk}")
    if fors_pk_ok:
        assert fors_pk_ok == fors_pk

    # COMPUTE THE HT_SIG
    # Compute lengths and heights for WOTS...
    wots_len1 = (8 * params.n + params.lgw - 1) // params.lgw  # ceil(8 * n / lgw)
    wots_len2 = 3   # always 3 for n=all(16,24,32) with lgw=4
    wots_len = wots_len1 + wots_len2
    DPRINT("wots_len =", wots_len)
    tree_ht = params.h // params.d  # h'
    assert tree_ht * params.d == params.h  # Check d divides h exactly
    # Loop invariants for the hypertree layers
    w = 1 << params.lgw                 # w = 2^lgw = 16 (always)
    ht_leaf_mask = (1 << tree_ht) - 1   # 2^{h'} - 1
    nleaves = 1 << tree_ht              # 2^{h'} leaves in each subtree

    # Input FORS public key to first WOTS signature
    wots_input = fors_pk

    # Loop for each of d subtrees in the HT
    for layer in range(params.d):  # SPX_D
        DPRINT(f"input to HT at layer {layer}={wots_input}")
        DPRINTA(f"tree_addr={tree_addr:x} idx_leaf={idx_leaf:x}")
        m = wots_chain(wots_input, False)
        DPRINT(m)
        DPRINT([hex(x) for x in m])
        DPRINT(f"len={len(m)}")

        # Compute the next WOTS signature.
        # [v3.1] Use separate skADRS of type WOTS_PRF to generate sk using PRF
        # but keep ADRS of type WOTS_HASH to derive chain
        adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
        adrs.setTreeAddress(tree_addr)
        adrs.setKeyPairAddress(idx_leaf)
        DPRINT(f"ADRS base={adrs.toHex(params.compr)}")
        skAdrs = adrs.copy()
        skAdrs.setType(Adrs.WOTS_PRF)
        skAdrs.setKeyPairAddress(adrs.getKeyPairAddress())

        htsigs = []
        for idx in range(wots_len):  # (35,51,67)
            DPRINTA(f"Generate WOTS+ private key for i = {idx}")
            # sk = PRF(PK.seed, SK.seed, ADRS)
            skAdrs.setChainAddress(idx)
            DPRINTA(f"ADRS={skAdrs.toHexSP(params.compr)}")
            wots_sk = params.PRF(PKseed, SKseed, skAdrs.toHex(params.compr), params)
            DPRINTA(f"wots_sk[{idx}]={wots_sk}")

            # Compute F^m_i(sk): apply the chaining function m[idx] times
            adrs.setChainAddress(idx)
            mi = m[idx]
            DPRINTA(f"m[{idx}]={mi}")
            x = wots_sk
            adrs_ht = Adrs.fromHex(adrs.toHex(params.compr))
            for i in range(mi):
                adrs_ht.setHashAddress(i)
                adrs_c = adrs_ht.toHex(params.compr)
                DPRINTA(f"i={i} ADRS={adrs_ht.toHexSP()}")
                DPRINTA(f"in={x}")
                x = params.F(PKseed, adrs_c, x, params)
                DPRINTA(f"F(PK.seed, ADRS, in)={x}")

            DPRINTA(f"ht_sig[{layer}][{idx}]:{x}")
            htsigs.append(x)

        DPRINT(f"ht_sig[{layer}] ({len(htsigs) * params.n} bytes)=\n{htsigs}")
        # Output this ht_sig (wots_len * n bytes) to the signature value
        sig += htsigs[0] + f" # ht_sig[{layer}]\n"
        sig += "\n".join(htsigs[1:]) + "\n"

        # Compute all 2^{h'} leaves (compressed WOTS+ public keys) of the subtree at this layer
        DPRINT("nleaves =", nleaves)
        leaves = []
        for this_leaf in range(nleaves):
            DPRINTA(f"this_leaf={this_leaf}")
            adrs = Adrs(Adrs.WOTS_HASH, layer=layer)
            adrs.setTreeAddress(tree_addr)
            adrs.setKeyPairAddress(this_leaf)
            DPRINTA(adrs.toHexSP(params.compr))
            skAdrs = adrs.copy()
            skAdrs.setType(Adrs.WOTS_PRF)
            skAdrs.setKeyPairAddress(adrs.getKeyPairAddress())
            heads = ""  # concatenation of heads of WOTS+ chains
            for chainaddr in range(wots_len):
                # [v3.1] Use WOTS_PRF to create sk, but WOTS_HASH for pk
                skAdrs.setChainAddress(chainaddr)
                # Compute secret value for chain i
                wots_sk = params.PRF(PKseed, SKseed, skAdrs.toHex(params.compr), params)
                DPRINTA(f"sk[{chainaddr}]={wots_sk}")
                # Compute public value for chain i
                adrs.setChainAddress(chainaddr)
                pk = chain(wots_sk, 0, w - 1, PKseed, adrs.toHex(params.compr), params,
                           showdebug=(DEBUGA and (chainaddr < 2 or chainaddr == wots_len - 1)))
                DPRINTA(f"pk={pk}")
                heads += pk

            DPRINT(f"Input to thash:\n{heads}")
            # for thash,
            wots_pk_adrs = Adrs(Adrs.WOTS_PK, layer=layer)
            wots_pk_adrs.setTreeAddress(tree_addr)
            wots_pk_adrs.setKeyPairAddress(this_leaf)
            DPRINT(f"wots_pk_addr={wots_pk_adrs.toHexSP(params.compr)}")
            # Compress public key
            leaf = params.T_len(PKseed, wots_pk_adrs.toHex(params.compr), heads, params)
            DPRINT(f"leaf[{this_leaf}]={leaf}")
            leaves.append(leaf)

        DPRINT(leaves)

        # Compute the root node of Merkle tree using H
        # Start with 2^{h/d} leaf values in array
        adrs = Adrs(Adrs.TREE, layer=layer)
        adrs.setTreeAddress(tree_addr)
        DPRINT(f"ADRS={adrs.toHex(params.compr)}")
        root = hash_root(leaves, adrs, PKseed, params)
        DPRINT(f"root={root}")

        # Compute the authentication path from leaf_idx
        DPRINT(f"Computing authpath for leaf index {idx_leaf}...")
        auth = authpath(leaves, adrs, PKseed, idx_leaf, params)
        DPRINT("authpath:")
        for node in auth:
            DPRINT(node)

        # Output this authpath to the signature value
        sig += auth[0] + f" # ht_auth_path[{layer}]\n"
        sig += "\n".join(auth[1:]) + "\n"

        # Set next wots_input to root and change tree_addr for next layer
        wots_input = root
        idx_leaf = tree_addr & ht_leaf_mask  # low h' bits
        tree_addr >>= tree_ht  # h'
        # Loop for next higher subtree...

    # At the end the final root value MUST equal PK.root
    DPRINT(f"Final root={wots_input}")
    DPRINT(f"Expecting ={PKroot}")
    if wots_input.lower() != PKroot.lower():
        print("ERROR: Final root of WOTS+ tree must equal PK.root")
        print(f"{wots_input.lower()} vs {PKroot.lower()}")

    # Check the size of the signature
    sig_lines = 1 + (params.k * (1 + params.a)) + (params.d * (wots_len + tree_ht))
    sig_bytes = sig_lines * params.n
    # Count outside the f-string: a backslash inside an f-string replacement field
    # is a SyntaxError before Python 3.12.
    nlines = sig.count("\n")
    DPRINT(f"sig lines = {nlines} (expecting {sig_lines})")
    # Strip sig string down to pure hex digits (drop " # ..." comments and line breaks)
    sighex = re.sub(r'\s+#.*?$', '', sig, flags=re.MULTILINE)
    sighex = sighex.replace('\n', '')
    DPRINT(f"sig bytes = {len(sighex) // 2} (expecting {sig_bytes})")

    # Compute HASH of signature as a check - use either SHA-256 or SHAKE
    if params.compr:  # True if using SHA-2
        hash_sig = hashfuncs.sha256(sighex)
        sighashalg = "sha256"
    else:  # using SHAKE128
        hash_sig = hashfuncs.shake128_256(sighex)
        sighashalg = "SHAKE128/256"
    print(f"HASH(sig)={hash_sig} {sighashalg}")
    if sig_hash_ok:
        print(f"Expected ={sig_hash_ok}")
        assert hash_sig == sig_hash_ok

    return sighex, sig  # The signature as one long line of hex chars + annotated signature


def slhdsa_sign(params, msg, sk, addrnd='', ctx='', internal=False, hash_sig=''):
    """SLH-DSA external sign. Pure signature only (no pre-hash). All strings hex encoded.

    INPUT: params    Parameter set object.
           msg       Message M, hex encoded.
           sk        Secret key, hex encoded.
           addrnd    Optional n bytes of additional randomness (hex); '' for deterministic.
           ctx       Optional context string of at most 255 bytes (hex).
           internal  If True, sign `msg` directly with slhdsa_sign_internal
                     (no M' wrapper; `ctx` is ignored).
           hash_sig  Optional expected hash of the signature; when non-empty it is
                     passed to slhdsa_sign_internal as `sig_hash_ok` and checked there.
    OUTPUT: Tuple (sighex, annotated_sig) as returned by slhdsa_sign_internal.
    RAISES: ValueError if `ctx` is not valid hex or is longer than 255 bytes.
    """
    if internal:
        return slhdsa_sign_internal(sk, msg, addrnd, params, sig_hash_ok=hash_sig)
    ctx_bytes = bytes.fromhex(ctx)  # ValueError on invalid hex; b'' for empty ctx
    if len(ctx_bytes) > 255:
        # An explicit exception, unlike `assert`, is not stripped under `python -O`
        raise ValueError("Context string ctx must be at most 255 bytes")
    # M' = toByte(0, 1) || toByte(|ctx|, 1) || ctx || M
    # |ctx| is taken from the decoded bytes, so whitespace accepted by fromhex cannot skew it.
    m = f"00{len(ctx_bytes):02x}" + ctx_bytes.hex() + msg
    return slhdsa_sign_internal(sk, m, addrnd, params, sig_hash_ok=hash_sig)


def save_sig_to_file(algname, sighex, sigannotated):
    """Write the annotated and raw-hex signature to files in the ./tests subfolder.

    Creates "./tests/sig-<algname>-out.txt" holding `sigannotated`, then
    "./tests/sig-<algname>-out.txt.raw.txt" holding `sighex`, printing each file name.
    The ./tests folder must already exist (open() does not create directories).
    """
    annotated_path = "./tests/sig-" + algname + "-out.txt"
    raw_path = annotated_path + ".raw.txt"
    for path, text in ((annotated_path, sigannotated), (raw_path, sighex)):
        with open(path, 'w') as fh:
            fh.write(text)
        print("Created file", path)