/* $Id: t_dibigdRSAKeyGen.cpp $ */
/*
* Copyright (C) 2026 David Ireland, D.I. Management Services Pty Limited
* <https://di-mgt.com.au/contact/> <https://di-mgt.com.au/bigdigits.html>
* SPDX-License-Identifier: MPL-2.0
*
* Last updated:
* $Date: 2026-05-05 02:02 $
* $Revision: 1.0.1 $
* $Author: dai $
*/
/* Generate a new 1025-bit RSA key and test that the RSA algorithm works on random data */
#ifdef NDEBUG
#undef NDEBUG
#endif
#include <iostream>
//#include <string>
#include <stdexcept>
#include <utility>
#include <cstdlib>
#include <assert.h>
#include <time.h>
#include "dibigd.hpp"
using std::cout;
using std::endl;
using namespace dibigd;
/*
 * Generate a fresh nbits = 1025-bit RSA key with fixed public exponent e = 0x10001,
 * derive the private exponent d and the CRT values (dP, dQ, qInv), then check on a
 * random message m that:
 *   - decrypting c = m^e mod n with c^d mod n recovers m,
 *   - "verifying" s = m^d mod n with s^e mod n recovers m,
 *   - CRT decryption of c recovers m.
 * Failures are reported via assert() (NDEBUG is undefined at the top of this file,
 * so asserts are always active). Progress and key values are printed to stdout.
 * CAUTION: this is raw textbook RSA, used only to test the mathematics.
 */
void do_test() {
    cout << "Generate a new 1025-bit RSA key and test that RSA algorithm works on random data." << endl;
    cout << "=================================================================================" << endl;
    const size_t nbits = 1025; // A power of two plus 1, to check boundary edge cases
    cout << "Generating a " << nbits << "-bit RSA key" << endl;
    BigDigit e(0x10001); // Fixed value 2^16 + 1
    e.print("e=");

    /* Generate (p,q) in two halves of approximately equal length */
    const size_t pbits = nbits / 2;
    const size_t qbits = nbits - pbits;

    /* Compute two primes p,q of the required lengths with p mod e != 1 and q mod e != 1,
       then check that n = pq has exactly nbits bits, else repeat.
       Because e is prime and p,q are primes much larger than e (assuming set_prime returns
       a prime of about the requested length), p mod e can never be 0, so "!= 1" is the
       same as "> 1". This ensures gcd(p-1, e) == gcd(q-1, e) == 1. */
    BigDigit p, q, n;
    do {
        cout << "Computing p (be patient)..." << endl;
        do {
            p.set_prime(pbits); // NB this takes some time
            p.printhex("p=");
        } while (p % e == 1);
        cout << "p is " << p.bitlen() << " bits" << endl;

        cout << "Computing q (be patient)..." << endl;
        do {
            q.set_prime(qbits); // NB this takes some time
            q.printhex("q=");
        } while (q % e == 1);
        cout << "q is " << q.bitlen() << " bits" << endl;

        /* Compute n = pq */
        n = p * q;
        cout << "n is " << n.bitlen() << " bits" << endl;
        // A pbits-bit by qbits-bit product has either nbits or nbits-1 bits,
        // so roughly half the time we must try again
    } while (n.bitlen() != nbits);
    n.printhex("Final n=");

    /* If q > p, swap p and q so that p > q.
       The CRT values need this: qInv = q^-1 mod p, and the recombination below
       reduces (m_1 - m_2) modulo p where m_2 < q. */
    if (q > p) {
        cout << "Swopping p and q" << endl;
        std::swap(p, q);
    }
    p.printhex("p=");
    q.printhex("q=");

    /* gcd((p-1)(q-1), e) == 1 should already hold, because p mod e != 1 and
       q mod e != 1 were enforced when generating p and q, but check anyway. */
    BigDigit phi = (p - 1) * (q - 1);
    phi.printhex("phi=");
    (phi.gcd(e)).print("gcd(phi,e)=");
    assert(phi.gcd(e) == 1);

    /* Private exponent: d = 1/e mod (p-1)(q-1) */
    BigDigit d = e.mod_inv(phi);
    d.printhex("d=");
    /* Check ed = 1 mod phi */
    assert((e * d) % phi == 1);

    /* CRT key values: dP = 1/e mod (p-1), dQ = 1/e mod (q-1), qInv = 1/q mod p */
    cout << "CRT values:" << endl;
    BigDigit dP = e.mod_inv(p - 1);
    BigDigit dQ = e.mod_inv(q - 1);
    BigDigit qInv = q.mod_inv(p);
    dP.printhex("dP=");
    dQ.printhex("dQ=");
    qInv.printhex("qInv=");

    /* Check that this RSA key performs the basic encryption and signing operations
       on a random message m < n.
       CAUTION: this is NOT how to do it in practice (e.g. PKCS-v1_5 requires padding).
       This just tests the mathematical properties. */
    cout << "Compute c = m^e mod n for random m < n..." << endl;
    // Generate random value m < n
    BigDigit m;
    m.set_rand_number(n);
    m.printhex("m=");

    // Encrypt c = m^e mod n (NB m is not a valid PKCS-v1_5 encryption block)
    BigDigit c = m.mod_exp(e, n);
    c.printhex("c=");

    // Decrypt m' = c^d mod n with one full-size exponentiation, timing it
    clock_t start = clock();
    BigDigit m1 = c.mod_exp(d, n);
    clock_t finish = clock();
    m1.printhex("m'=");
    assert(m1 == m);
    cout << "OK, successfully decrypted m." << endl;
    double interval = static_cast<double>(finish - start) / CLOCKS_PER_SEC;
    cout << "Decryption by inversion took " << interval << " seconds" << endl;

    /* Sign s = m^d mod n (NB m is not a valid PKCS-v1_5 signature block) */
    cout << "Compute signature s = m^d mod n..." << endl;
    BigDigit s = m.mod_exp(d, n);
    s.printhex("s=");
    /* Verify: m' = s^e mod n must equal m */
    m1 = s.mod_exp(e, n);
    m1.printhex("m'=");
    assert(m1 == m);
    cout << "OK, successfully verified signature over m." << endl;

    /* Now decrypt c again using the CRT method, timing it for comparison */
    cout << "Decrypt using CRT method..." << endl;
    c.printhex("Input c=");
    start = clock();
    BigDigit m_1 = c.mod_exp(dP, p);                      // m_1 = c^dP mod p
    BigDigit m_2 = c.mod_exp(dQ, q);                      // m_2 = c^dQ mod q
    BigDigit h = qInv.mod_mult(m_1.mod_sub(m_2, p), p);   // h = qInv (m_1 - m_2) mod p
    m1 = m_2 + h * q;                                     // m = m_2 + hq
    finish = clock();
    m1.printhex("m'=");
    assert(m1 == m);
    cout << "OK, successfully decrypted m using CRT method." << endl;
    interval = static_cast<double>(finish - start) / CLOCKS_PER_SEC;
    cout << "Decryption by CRT took " << interval << " seconds" << endl;
}
/*
 * Program entry point: runs do_test(), reporting any exception that escapes it.
 * Returns EXIT_SUCCESS if do_test() completes, or EXIT_FAILURE if an exception was
 * caught, so that scripts and CI running this test can detect the failure.
 * (A failed assert() inside do_test() aborts the process instead.)
 */
int main() {
    /* MSVC memory leak checking stuff */
#if _MSC_VER >= 1100
    _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDOUT);
    _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDOUT);
#endif
    int status = EXIT_SUCCESS;
    /* Catch any exceptions, report them, and reflect the failure in the exit status */
    try {
        do_test();
    }
    catch (const std::exception& e) {
        // Handle standard exceptions with a specific message
        std::cerr << "Standard exception: " << e.what() << std::endl;
        status = EXIT_FAILURE;
    }
    catch (...) {
        std::cerr << "Caught unknown exception." << std::endl;
        status = EXIT_FAILURE;
    }
    cout << endl << "ALL DONE." << endl << endl;
    return status;
}