/* $Id: t_dibigdRSAKeyGen.cpp $ */

/*
* Copyright (C) 2026 David Ireland, D.I. Management Services Pty Limited
* <https://di-mgt.com.au/contact/> <https://di-mgt.com.au/bigdigits.html>
* SPDX-License-Identifier: MPL-2.0
*
* Last updated:
* $Date: 2026-05-05 02:02 $
* $Revision: 1.0.1 $
* $Author: dai $
*/

/* Generate a new 1025-bit RSA key and test that the RSA algorithm works on random data */

#ifdef NDEBUG
#undef NDEBUG
#endif
#include <iostream>
//#include <string>
#include <stdexcept>
#include <utility>
#include <assert.h>
#include <time.h>
#include "dibigd.hpp"
using std::cout;
using std::endl;
using namespace dibigd;

void do_test() {
    cout << "Generate a new 1025-bit RSA key and test that RSA algorithm works on random data." << endl;
    cout << "=================================================================================" << endl;
    
    size_t nbits = 1025;  // Pick a power of two plus 1 to check boundary edge cases
    cout << "Generating a " << nbits << "-bit RSA key" << endl;
    BigDigit e(0x10001);    // Fixed value 2^16 + 1
    e.print("e=");
    /* Generate (p,q) in two halves, approx equal */
    size_t pbits = nbits / 2;
    size_t qbits = nbits - pbits;
    /* Compute two primes p,q of required length with p mod e > 1 then check that n=pq is of required nbits length, else repeat */
    BigDigit p, q, n;
    do {
        cout << "Computing p (be patient)..." << endl;
        do {
            p.set_prime(pbits);    // NB this takes some time
            p.printhex("p=");
        } while (p % e == 1);
        cout << "p is " << p.bitlen() << " bits" << endl;
        cout << "Computing q (be patient)..." << endl;
        do {
            q.set_prime(qbits);    // NB this takes some time
            q.printhex("q=");
        } while (q % e == 1);
        cout << "q is " << q.bitlen() << " bits" << endl;

        /* Compute n = pq */
        n = p * q;
        cout << "n is " << n.bitlen() << " bits" << endl;
        // Half the time we get the correct number of bits, otherwise try again
    } while (n.bitlen() != nbits);

    n.printhex("Final n=");

    /* If q > p swap p and q so p > q 
       - we need this to compute the CRT key values */
    if (q > p) {
        cout << "Swopping p and q" << endl;
        // Swop using XOR
        q = q ^ p;
        p = q ^ p;
        q = q ^ p;
    }
    p.printhex("p=");
    q.printhex("q=");

    /* This value of n should already comply with requirement that gcd((p-1)(q-1), e) == 1 
    (because we checked that (p,q) mod e > 1 when generating p and q)
    but we'll check anyway.
    */
    BigDigit phi = (p - 1) * (q - 1);
    phi.printhex("phi=");
    (phi.gcd(e)).print("gcd(phi,e)=");
    assert(phi.gcd(e) == 1);

    /* Compute inverse of e modulo phi: d = 1/e mod (p-1)(q-1) */
    BigDigit d = e.mod_inv(phi);
    d.printhex("d=");

    /* Check ed = 1 mod phi */
    assert((e * d) % phi == 1);

    /* Calculate CRT key values (dP, dQ, qInv) */
    cout << "CRT values:" << endl;
    BigDigit dP = e.mod_inv(p - 1);
    BigDigit dQ = e.mod_inv(q - 1);
    BigDigit qInv = q.mod_inv(p);
    dP.printhex("dP=");
    dQ.printhex("dQ=");
    qInv.printhex("qInv=");

    /* Do some checks that this RSA key performs the basic encryption and signing algorithms 
    on a random message m < n.
    CAUTION: this is NOT how to do it practice like PKCS-v1_5. 
    This just tests the mathematical properties.
    */
    cout << "Compute c = m^e mod n for random m < n..." << endl;
    // Generate random value m < n
    BigDigit m;
    m.set_rand_number(n);
    m.printhex("m=");
    // Encrypt c = m^e mod n (NB m is not a valid PKCS-v1_5 encryption block)
    BigDigit c = m.mod_exp(e, n);
    c.printhex("c=");
    // Decrypt m' = c^d mod n
    clock_t start = clock();
    BigDigit m1 = c.mod_exp(d, n);
    clock_t finish = clock();
    m1.printhex("m'=");
    assert(m1 == m);
    cout << "OK, successfully decrypted m." << endl;
    double interval = (double)(finish - start) / CLOCKS_PER_SEC;
    cout << "Decryption by inversion took " << interval << " seconds" << endl;

    /* Sign s = m^d mod n (NB m is not a valid PKCS-v1_5 signature block) */
    cout << "Compute signature s = m^d mod n..." << endl;
    BigDigit s = m.mod_exp(d, n);
    s.printhex("s=");
    /* Check verify m' = s^e mod n */
    m1 = s.mod_exp(e, n);
    m1.printhex("m'=");
    assert(m1 == m);
    cout << "OK, successfully verified signature over m." << endl;

    /* Now decrypt using CRT method */
    m1 = 0;
    cout << "Decrypt using CRT method..." << endl;
    c.printhex("Input c=");
    start = clock();
    /* Let m_1 = c^dP mod p. */
    BigDigit m_1 = c.mod_exp(dP, p);
    /* Let m_2 = c^dQ mod q. */
    BigDigit m_2 = c.mod_exp(dQ, q);
    /* Let h = qInv ( m_1 - m_2 ) mod p. */
    BigDigit h = qInv.mod_mult(m_1.mod_sub(m_2, p), p);
    /* Let m = m_2 + hq. */
    m1 = m_2 + h * q;
    finish = clock();
    m1.printhex("m'=");
    assert(m1 == m);
    cout << "OK, successfully decrypted m using CRT method." << endl;
    interval = (double)(finish - start) / CLOCKS_PER_SEC;
    cout << "Decryption by CRT took " << interval << " seconds" << endl;
}

/*
 * Entry point: run do_test() and report any exception it throws.
 * Returns 0 on success and 1 if an exception escaped do_test(), so a test
 * harness can detect the failure. Assertion failures inside do_test() call
 * abort() directly and never reach the handlers below.
 */
int main() {
    /* MSVC debug-heap setup: enable allocation tracking and leak reporting at exit,
       and send CRT warnings/errors to stdout.
       NOTE(review): the _Crt* functions are declared in <crtdbg.h>, which this file
       does not include directly; presumably it arrives via dibigd.hpp. Confirm. */
#if defined(_MSC_VER) && _MSC_VER >= 1100
    _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDOUT);
    _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDOUT);
#endif

    /* Report any exception, and remember that the run failed. */
    int status = 0;
    try {
        do_test();
    }
    catch (const std::exception& e) {
        // Handle standard exceptions with a specific message
        std::cerr << "Standard exception: " << e.what() << std::endl;
        status = 1;
    }
    catch (...) {
        std::cerr << "Caught unknown exception." << std::endl;
        status = 1;
    }

    cout << endl << "ALL DONE." << endl << endl;

    return status;
}