# @file slh_adrs.py 
# @version 1.1.0 (2026-02-15T08:23Z)
# @author David Ireland <https://di-mgt.com.au/contact>
# @copyright 2023-26 DI Management Services Pty Ltd
# @license Apache-2.0

"""SLH-DSA ADRS Class."""

"""Usage:
    from slh_adrs import Adrs
    adrs = Adrs()
    adrs.setType(adrs.WOTS_HASH)
    print("adrs =", adrs.toHex(compressed=True))
    adrs = Adrs.fromHex("1528daecdc86eb87610300000002000000020000000d")
    
Format (SHA-2 compressed, 22 bytes):
0028daecdc86eb87610300000006000000000000001b
layer   treeaddr  type    word1    word2    word3
[1]          [8]   [1]      [4]      [4]      [4]
00 28daecdc86eb8761 03 00000006 00000000 0000001b
0  1                9  10       14       18  # byte offsets
0  2                18 20       28       36  # hex offsets

Format (SHAKE uncompressed, 32 bytes):
0000000000000000006894ec35fadda8000000030000000d0000000000000ca2
layer[4] treeaddr[12]             type[4]  word1[4] word2[4] word3[4]
00000000 00000000006894ec35fadda8 00000003 0000000d 00000000 00000ca2
0        4                        16       20       24       28  # byte offsets
0        8                        32       40       48       56  # hex offsets

--------------------------------------------------------------------------
Type                    word1        word2      word3       Type constant      
-------------------------------------------------------------------------                 
0 WOTS+ hash addr       keypairaddr  chainaddr  hashaddr    WOTS_HASH
1 WOTS+ pub key compr   keypairaddr  0          0           WOTS_PK
2 Hash tree addr        0            tree ht    tree index  TREE
3 FORS tree addr        keypairaddr  tree ht    tree index  FORS_TREE  
4 FORS tree roots compr keypairaddr  0          0           FORS_ROOTS
5 WOTS+ key generation  keypairaddr  chainaddr  0           WOTS_PRF
6 FORS key generation   keypairaddr  0          tree index  FORS_PRF
"""


class Adrs:
    """Class for SPHINCS+ / SLH-DSA hash addresses (ADRS).

    An ADRS consists of six unsigned integer fields: layer address, tree
    address, type, and three type-dependent words (word1..word3 -- see the
    table at the top of this file for their meaning per type).

    The fields serialize to hex in one of two layouts:
      * full, 32 bytes / 64 hex digits (SHAKE parameter sets), or
      * compressed, 22 bytes / 44 hex digits (SHA-2 parameter sets).

    Setters return ``self`` so calls can be chained, e.g.
    ``Adrs().setLayerAddress(1).setType(Adrs.TREE)``.
    """

    # Type constants
    WOTS_HASH = 0
    WOTS_PK = 1
    TREE = 2
    FORS_TREE = 3
    FORS_ROOTS = 4
    WOTS_PRF = 5
    FORS_PRF = 6

    # Field sizes in bytes: (layer, treeaddr, type, each of word1..word3).
    _FULL_SIZES = (4, 12, 4, 4)        # 32-byte ADRS
    _COMPRESSED_SIZES = (1, 8, 1, 4)   # 22-byte compressed ADRS
    _HEX_CHARS = frozenset('0123456789abcdefABCDEF')

    def __init__(self, adrs_type: int = 0, layer: int = 0, treeaddr: int = 0,
                 word1: int = 0, word2: int = 0, word3: int = 0):
        self.adrs_type = adrs_type
        self.layer = layer
        self.treeaddr = int(treeaddr)
        # Last 3 words have different meanings depending on type (see table above)
        # so we use generic names
        self.word1 = word1
        self.word2 = word2
        self.word3 = word3

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(adrs_type={self.adrs_type}, "
                f"layer={self.layer}, treeaddr={self.treeaddr:#x}, "
                f"word1={self.word1}, word2={self.word2}, word3={self.word3})")

    def copy(self) -> "Adrs":
        """Create an independent copy of this Adrs object."""
        return Adrs(self.adrs_type, self.layer, self.treeaddr,
                    self.word1, self.word2, self.word3)

    @classmethod
    def _sizes(cls, compressed) -> tuple:
        """Return the byte sizes of the six fields, in serialization order."""
        layer, tree, typ, word = cls._COMPRESSED_SIZES if compressed else cls._FULL_SIZES
        return (layer, tree, typ, word, word, word)

    def _hex_fields(self, compressed) -> list:
        """Return the six fields as fixed-width lowercase hex strings.

        Each value is reduced modulo 256**size, i.e. only its low-order bytes
        are kept, as FIPS 205 toByte() and ADRS compression do. This makes the
        output always exactly 44 (compressed) or 64 hex digits, so it
        round-trips through fromHex(). In-range values are unaffected.
        """
        values = (self.layer, self.treeaddr, self.adrs_type,
                  self.word1, self.word2, self.word3)
        return [format(value % (1 << (8 * size)), f'0{2 * size}x')
                for value, size in zip(values, self._sizes(compressed))]

    def toHex(self, compressed: bool) -> str:
        """Return ADRS in hex format.
        @compressed Set True for compressed SHA-2 format (44 hex digits),
                    False for the full format (64 hex digits)
        @remark To avoid bugs, MUST specify compressed=True or False
        """
        return ''.join(self._hex_fields(compressed))

    def toHexSP(self, compressed: bool = True) -> str:
        """Return ADRS in hex format with a space between each field."""
        return ' '.join(self._hex_fields(compressed))

    @classmethod
    def fromHex(cls, hexval) -> "Adrs":
        """Read in address in hex to new Adrs object.

        @hexval 44 hex digits (compressed) or 64 hex digits (full format);
                str, or ASCII bytes/bytearray
        @raises ValueError on wrong length or any non-hex character
        """
        if isinstance(hexval, (bytes, bytearray)):
            hexval = hexval.decode('ascii')
        if len(hexval) == 44:
            compressed = True
        elif len(hexval) == 64:
            compressed = False
        else:
            raise ValueError("Expected hex string of 44 or 64 chars.")
        # int(s, 16) alone would also accept '_' between digits, whitespace
        # and a sign, so malformed strings are rejected explicitly here.
        if not cls._HEX_CHARS.issuperset(hexval):
            raise ValueError(f"Invalid hex digit in ADRS string: {hexval!r}")
        values = []
        pos = 0
        for size in cls._sizes(compressed):
            end = pos + 2 * size
            values.append(int(hexval[pos:end], 16))
            pos = end
        layer, treeaddr, adrs_type, word1, word2, word3 = values
        return cls(adrs_type, layer, treeaddr, word1, word2, word3)

    def setType(self, adrs_type: int) -> "Adrs":
        """Set the type field and clear word1..word3 to 0."""
        self.adrs_type = adrs_type
        # Changing type initializes the subsequent 3 words to 0
        self.word1 = 0
        self.word2 = 0
        self.word3 = 0
        return self

    def setKeyPairAddress(self, kpa: int) -> "Adrs":
        """Set the key pair address (stored in word1)."""
        self.word1 = kpa
        return self

    def getKeyPairAddress(self) -> int:
        """Return the key pair address (word1)."""
        return self.word1

    def setTreeHeight(self, ht: int) -> "Adrs":
        """Set the tree height (stored in word2)."""
        self.word2 = ht
        return self

    def getTreeHeight(self) -> int:
        """Return the tree height (word2)."""
        return self.word2

    def setTreeIndex(self, idx: int) -> "Adrs":
        """Set the tree index (stored in word3)."""
        self.word3 = idx
        return self

    def getTreeIndex(self) -> int:
        """Return the tree index (word3)."""
        return self.word3

    def setChainAddress(self, ca: int) -> "Adrs":
        """Set the WOTS+ chain address (stored in word2)."""
        self.word2 = ca
        return self

    def setHashAddress(self, ha: int) -> "Adrs":
        """Set the WOTS+ hash address (stored in word3)."""
        self.word3 = ha
        return self

    def setLayerAddress(self, la: int) -> "Adrs":
        """Set the layer address."""
        self.layer = la
        return self

    def setTreeAddress(self, ta: int) -> "Adrs":
        """Set the tree address (converted with int())."""
        self.treeaddr = int(ta)
        return self

    def getTreeAddress(self) -> int:
        """Return the tree address."""
        return self.treeaddr



if __name__ == '__main__':
    # Demonstration: print ADRS values in both the compressed (22-byte) and
    # full (32-byte) hex layouts. fromHex is a classmethod, so it is called
    # on the class rather than through an instance.

    # A new ADRS object is all zeros
    adrs = Adrs()
    print(adrs.toHex(compressed=True))   # Empty, all zeros, 22 bytes
    print(adrs.toHex(compressed=False))  # Empty, all zeros, 32 bytes

    # FORS tree address with keypair address 6 and tree index 27
    adrs = Adrs(Adrs.FORS_TREE, 0, 0x28daecdc86eb8761, word3=27, word1=6)
    print(adrs.toHex(compressed=True))
    # Changing the type clears word1..word3
    adrs.setType(Adrs.FORS_ROOTS)
    print(adrs.toHex(compressed=True))
    adrs.setType(Adrs.WOTS_HASH)
    print(adrs.toHex(compressed=True))

    # Round trips through fromHex (compressed layout)
    hexval = adrs.toHex(compressed=True)
    print(hexval)
    print(Adrs.fromHex(hexval).toHex(compressed=True))
    hexval = "1528daecdc86eb87610300000001000000030000000e"
    print(hexval)
    print(Adrs.fromHex(hexval).toHex(compressed=True))
    adrs = Adrs.fromHex(hexval)
    adrs.setType(Adrs.TREE)
    print(adrs.toHex(compressed=True))
    adrs = Adrs.fromHex("1528daecdc86eb87610300000002000000020000000d")
    print(adrs.toHex(compressed=True))
    print(adrs.toHexSP(compressed=True))

    # Full 32-byte (uncompressed) layout
    adrs = Adrs.fromHex("0000000000000000006894ec35fadda8000000030000000d0000000000000ca2")
    print(adrs.toHex(compressed=False))
    print(adrs.toHexSP(compressed=False))
    adrs = Adrs.fromHex("000000010000000000006894ec35fadd00000000000000a80000000000000002")
    print(adrs.toHex(compressed=False))
    print(adrs.toHexSP(compressed=False))

    # copy() is independent: changing the copy leaves the original intact
    newadrs = adrs.copy()
    print("newadrs...")
    print(newadrs.toHexSP(compressed=True))
    newadrs.setType(Adrs.FORS_PRF)
    print("NEW", newadrs.toHexSP(compressed=True))
    print("OLD", adrs.toHexSP(compressed=True))