# @file slh_hashfuncs.py 
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-25 DI Management Services Pty Ltd
# @license Apache-2.0

"""SLH-DSA hash functions tailored for each parameter set."""

import hashlib_pure as hash
from slh_params import params


def BlockPad(PKseed, blockbytes=64):
    """Right-pad the hex string PKseed with "0" characters to a full block.

    The target width is ``2 * blockbytes`` characters because each byte is
    represented by two hex digits. A value that is already at least that
    long is returned without any padding.
    """
    hex_width = 2 * blockbytes
    return PKseed.ljust(hex_width, "0")


def H_msg_shake(R, PKseed, PKroot, M, params):
    """Compute H_msg with SHAKE256 over R || PKseed || PKroot || M.

    ``params.m`` is passed as the second SHAKE256 argument (presumably the
    requested output length -- confirm against hashlib_pure).
    """
    preimage = R + PKseed + PKroot + M
    return hash.SHAKE256(preimage, params.m)


# v3.1
def H_msg_sha256(R, PKseed, PKroot, M, params):
    """Compute H_msg with SHA-256 and MGF1-SHA-256.

    An inner SHA-256 digest of R || PKseed || PKroot || M is appended to
    R || PKseed, and that string is passed to MGF1-SHA-256 together with
    ``params.m`` (presumably the requested output length -- confirm against
    hashlib_pure).
    """
    prefix = R + PKseed
    inner_digest = hash.SHA256(prefix + PKroot + M)
    return hash.MGF1_SHA256(prefix + inner_digest, params.m)


# v3.1
def H_msg_sha512(R, PKseed, PKroot, M, params):
    """Compute H_msg with SHA-512 and MGF1-SHA-512.

    An inner SHA-512 digest of R || PKseed || PKroot || M is appended to
    R || PKseed, and that string is passed to MGF1-SHA-512 together with
    ``params.m`` (presumably the requested output length -- confirm against
    hashlib_pure).
    """
    prefix = R + PKseed
    inner_digest = hash.SHA512(prefix + PKroot + M)
    return hash.MGF1_SHA512(prefix + inner_digest, params.m)

# v3.1
def PRF_shake(PKseed, SKseed, adrs, params):
    """Compute PRF with SHAKE256 over PKseed || adrs || SKseed.

    Note the hashing order differs from the parameter order: the address
    comes before the secret seed. ``params.n`` is passed as the second
    SHAKE256 argument (presumably the output length -- confirm against
    hashlib_pure).
    """
    preimage = PKseed + adrs + SKseed
    return hash.SHAKE256(preimage, params.n)


# PRF is the same for all SHA-2 security categories (1, 3 and 5)
def PRF_sha256(PKseed, SKseed, adrs, params):
    """Compute PRF with SHA-256 over BlockPad(PKseed) || adrs || SKseed.

    PKseed is first zero-padded by BlockPad using its default block size.
    Only the first ``2 * params.n`` characters of the digest are returned,
    i.e. n bytes when the digest is hex-encoded.
    """
    padded_seed = BlockPad(PKseed)
    digest = hash.SHA256(padded_seed + adrs + SKseed)
    return digest[:2 * params.n]


def PRF_msg_shake(SKprf, optrand, msg, params):
    """Compute PRF_msg with SHAKE256 over SKprf || optrand || msg.

    ``params.n`` is passed as the second SHAKE256 argument (presumably the
    requested output length -- confirm against hashlib_pure).
    """
    preimage = SKprf + optrand + msg
    return hash.SHAKE256(preimage, params.n)


def PRF_msg_sha256(SKprf, optrand, msg, params):
    """Compute PRF_msg with HMAC-SHA-256 keyed by SKprf.

    The MAC is taken over optrand || msg. Only the first ``2 * params.n``
    characters of the result are returned, i.e. n bytes when the MAC is
    hex-encoded.
    """
    mac_input = optrand + msg
    mac = hash.HMAC_SHA256(SKprf, mac_input)
    return mac[:2 * params.n]


def PRF_msg_sha512(SKprf, optrand, msg, params):
    """Compute PRF_msg with HMAC-SHA-512 keyed by SKprf.

    The MAC is taken over optrand || msg. Only the first ``2 * params.n``
    characters of the result are returned, i.e. n bytes when the MAC is
    hex-encoded.
    """
    mac_input = optrand + msg
    mac = hash.HMAC_SHA512(SKprf, mac_input)
    return mac[:2 * params.n]


def F_shake(PKseed, adrs, M, params):
    """Compute F with SHAKE256 over PKseed || adrs || M.

    ``params.n`` is passed as the second SHAKE256 argument (presumably the
    requested output length -- confirm against hashlib_pure).
    """
    preimage = PKseed + adrs + M
    return hash.SHAKE256(preimage, params.n)


# F is the same for all SHA-2 security categories (1, 3 and 5)
def F_sha256(PKseed, adrs, M, params):
    """Compute F with SHA-256 over BlockPad(PKseed) || adrs || M.

    PKseed is first zero-padded by BlockPad using its default block size.
    Only the first ``2 * params.n`` characters of the digest are returned,
    i.e. n bytes when the digest is hex-encoded.
    """
    padded_seed = BlockPad(PKseed)
    digest = hash.SHA256(padded_seed + adrs + M)
    return digest[:2 * params.n]


def H_shake(PKseed, adrs, M1, M2, params):
    """Compute H with SHAKE256 over PKseed || adrs || M1 || M2.

    ``params.n`` is passed as the second SHAKE256 argument (presumably the
    requested output length -- confirm against hashlib_pure).
    """
    preimage = PKseed + adrs + M1 + M2
    return hash.SHAKE256(preimage, params.n)


def H_sha256(PKseed, adrs, M1, M2, params):
    """Compute H with SHA-256 over BlockPad(PKseed) || adrs || M1 || M2.

    PKseed is first zero-padded by BlockPad using its default block size.
    Only the first ``2 * params.n`` characters of the digest are returned,
    i.e. n bytes when the digest is hex-encoded.
    """
    padded_seed = BlockPad(PKseed)
    digest = hash.SHA256(padded_seed + adrs + M1 + M2)
    return digest[:2 * params.n]


def H_sha512(PKseed, adrs, M1, M2, params):
    """Compute H with SHA-512 over BlockPad(PKseed, 128) || adrs || M1 || M2.

    PKseed is first zero-padded by BlockPad to a 128-byte block. Only the
    first ``2 * params.n`` characters of the digest are returned, i.e. n
    bytes when the digest is hex-encoded.
    """
    padded_seed = BlockPad(PKseed, 128)
    digest = hash.SHA512(padded_seed + adrs + M1 + M2)
    return digest[:2 * params.n]


def T_len_shake(PKseed, adrs, M, params):
    """Compute T_len with SHAKE256 over PKseed || adrs || M.

    ``params.n`` is passed as the second SHAKE256 argument (presumably the
    requested output length -- confirm against hashlib_pure).
    """
    preimage = PKseed + adrs + M
    return hash.SHAKE256(preimage, params.n)


def T_len_sha256(PKseed, adrs, M, params):
    """Compute T_len with SHA-256 over BlockPad(PKseed) || adrs || M.

    PKseed is first zero-padded by BlockPad using its default block size.
    Only the first ``2 * params.n`` characters of the digest are returned,
    i.e. n bytes when the digest is hex-encoded.
    """
    padded_seed = BlockPad(PKseed)
    digest = hash.SHA256(padded_seed + adrs + M)
    return digest[:2 * params.n]


def T_len_sha512(PKseed, adrs, M, params):
    """Compute T_len with SHA-512 over BlockPad(PKseed, 128) || adrs || M.

    PKseed is first zero-padded by BlockPad to a 128-byte block. Only the
    first ``2 * params.n`` characters of the digest are returned, i.e. n
    bytes when the digest is hex-encoded.
    """
    padded_seed = BlockPad(PKseed, 128)
    digest = hash.SHA512(padded_seed + adrs + M)
    return digest[:2 * params.n]


# Export simple hash funcs using hex-encoded params and return values.
def sha256(msghex):
    """Return the SHA-256 digest of msghex as computed by hashlib_pure.

    Thin convenience wrapper; the argument and the result are expected to
    be hex-encoded strings.
    """
    digest = hash.SHA256(msghex)
    return digest


def shake128_256(msghex):
    """Return hashlib_pure's SHAKE128_256 output for msghex.

    Thin convenience wrapper; the argument and the result are expected to
    be hex-encoded strings. The name suggests SHAKE128 with a 256-bit
    output -- confirm against hashlib_pure.
    """
    digest = hash.SHAKE128_256(msghex)
    return digest