/* $Id: t_dibigdPoly1305.cpp $ */

/*
* Copyright (C) 2026 David Ireland, D.I. Management Services Pty Limited
* <https://di-mgt.com.au/contact/> <https://di-mgt.com.au/bigdigits.html>
* SPDX-License-Identifier: MPL-2.0
*
* Last updated:
* $Date: 2026-05-05 02:02 $
* $Revision: 1.0.1 $
* $Author: dai $
*/

/* Compute a Poly1305 authentication tag.
Ref: RFC8439 "ChaCha20 and Poly1305 for IETF Protocols"
Nir & Langley, June 2018. Section 2.5.2
<https://www.rfc-editor.org/rfc/rfc8439.html#section-2.5.2>
*/

#ifdef NDEBUG
#undef NDEBUG
#endif
#include <iostream>
#include <string>
#include <iomanip>
#include <stdexcept>
#include <vector>
#include <assert.h>
#include <time.h>
#include "dibigd.hpp"
using std::cout;
using std::endl;
using namespace dibigd;

/* Smaller of a and b. Being a macro, the selected argument is evaluated twice
   (once in the comparison, once as the result), so do not pass expressions
   with side effects. */
#define MIN(a,b) (((a) < (b)) ? (a) : (b))

void do_test() {
/*
The inputs to Poly1305 are:
* A 256-bit one-time key
* An arbitrary length message
The output is a 128-bit tag.
*/
    cout << "Compute a Poly1305 authentication tag as per RFC 8439." << endl;
    cout << "======================================================" << endl;
    // Assume that we got the following keying material:
    // 256-bit key in network order
    BigDigit k("0x85d6be7857556d337f4452fe42d506a80103808afb0db2fd4abff6af4149f51b");
    k.printhex("k=");
    // Message to be Authenticated in network order (be careful about leading zeros in message - they count!)
    std::vector<unsigned char> msgvec = Bvec::from_hex("0x43727970746f6772617068696320466f72756d2052657365617263682047726f7570");
    size_t mbytes = msgvec.size();
    cout << "msg length in bytes " << mbytes << endl;
    Bvec::print_hex(msgvec);

    // Set the constant prime "P" to be 2^130-5:

    /* WARNING: if you do this:
        BigDigit P = (1 << 130) - 5; // Interprets (1 << 130) in usual 32-bit arithmetic
    // and you get warning C4293: '<<' : shift count negative or too big, undefined behavior.
    // So do the following instead:
    */
    BigDigit P = (BigDigit(1) << 130) - 5;  // Make sure '<<' operates on a BigDigit

    P.print("P=");    // 1361129467683753853853498429727072845819
    P.printhex();    // 0x3fffffffffffffffffffffffffffffffb
    // set a 128-bit mask of 1's - we'll use this below
    BigDigit mask128 = ((BigDigit(1) << 128) - 1);
    // split k into(r, s)
    BigDigit s = k & mask128;
    s.printhex("s in network order:  ");
    // Set s as 128-bit little-endian number
    s = s.reverse_octets(128 / 8);
    s.printhex("s as 128-bit number: ");
    cout << "(Correct s :         0x1bf54941aff6bf4afdb20dfb8a800301)" << endl;
    // r is the left-hand 128 bits
    BigDigit r = (k >> 128) & mask128;
    r.printhex("r=");
    // Set r as 128-bit little-endian number
    r = r.reverse_octets(128 / 8);
    // Clamp r
    r = r & (BigDigit("0x0ffffffc0ffffffc0ffffffc0fffffff"));
    r.printhex("r after clamping:    ");
    cout << "(Correct r         : 0x806d5400e52447c036d555408bed685)" << endl;

    BigDigit msg = from_octets(msgvec);
    // Reverse order of message bytes then take in blocks of 128 bits
    msg = msg.reverse_octets(mbytes);
    size_t nblocks = (mbytes + 15) / 16;
    size_t nleft = mbytes;
    BigDigit acc;
    // Loop through each block of 16 bytes (128 bits)
    for (size_t i = 0; i < nblocks; i++) {
        // # Get next 16 bytes from RHS
        BigDigit block = msg & mask128;
        // Add leading 0x01 byte
        size_t blklen = MIN(nleft, 16);
        nleft = nleft - blklen;
        block = (BigDigit("0x01") << (blklen)* 8) | block;
        block.printhex("Block with 0x01 byte =");
        // Main calc: a += n; a = (r * a) % p
        acc = (acc + block).mod_mult(r, P);
        // Shift message block by 16 bytes
        msg >>= 128;
    }
    BigDigit tag = acc + s;
    tag.print("tag as LE number: ");
    // 905406785994486245610219399192143267496
    // Serialize to get the tag as a 16-byte octet string, i.e. reverse order of bytes
    tag = tag.reverse_octets(16);
    tag.printhex("Tag:     ");
    cout << "Correct: 0xa8061dc1305136c6c22b8baf0c0127a9" << endl;
    assert(tag == "0xa8061dc1305136c6c22b8baf0c0127a9");
    // Serialize tag to an octet string
    std::vector<unsigned char> TAG = tag.to_octets(16);
    Bvec::print_hex(TAG, ":");
}

/*
 * Program entry point. Enables MSVC debug-heap leak reporting (MSVC builds
 * only), runs do_test(), and reports any exception it throws.
 *
 * Returns 0 on success and 1 if do_test() threw, so that scripts and CI can
 * detect a failed run from the exit status. (A failed assert() aborts the
 * process, which also yields a non-zero status.)
 */
int main() {
    /* MSVC memory leak checking stuff */
    /* defined() check avoids -Wundef warnings on non-MSVC compilers */
#if defined(_MSC_VER) && _MSC_VER >= 1100
    _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDOUT);
    _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDOUT);
#endif

    int status = 0;

    /* Catch any exceptions, report them, and remember that the run failed */
    try {
        do_test();
    }
    catch (const std::exception& e) {
        // Handle standard exceptions with a specific message
        std::cerr << "Standard exception: " << e.what() << std::endl;
        status = 1;
    }
    catch (...) {
        std::cerr << "Caught unknown exception." << std::endl;
        status = 1;
    }

    cout << endl << "ALL DONE." << endl << endl;

    return status;
}