/* $Id: t_dibigdDHGen.cpp $ */

/*
* Copyright (C) 2026 David Ireland, D.I. Management Services Pty Limited
* <https://di-mgt.com.au/contact/> <https://di-mgt.com.au/bigdigits.html>
* SPDX-License-Identifier: MPL-2.0
*
* Last updated:
* $Date: 2026-05-05 02:02 $
* $Revision: 1.0.1 $
* $Author: dai $
*/

/* Generate Diffie-Hellman domain parameters (p,q,g)
where p is a prime of L bits and q is a prime of N bits,
q is a prime divisor of p-1, i.e. p = jq+1 with j>=2;
and g generates the subgroup of order q mod p, i.e. 1 < g < p and g^q mod p = 1
(since q is prime, any such g != 1 has order exactly q).
*/

#ifdef NDEBUG
#undef NDEBUG
#endif
#include <iostream>
#include <string>
#include <iomanip>
#include <stdexcept>
#include <vector>
#include <assert.h>
#include <time.h>
#include "dibigd.hpp"
using std::cout;
using std::endl;
using namespace dibigd;


void do_test() {
    cout << "Generate Diffie-Hellman domain parameters (p,q,g)." << endl;
    cout << "==================================================" << endl;

    // NB should do L=2048 and N=256, but this is a demo, so we'll do it shorter and quicker.
    size_t L = 1024;
    size_t N = 160;
    BigDigit p, q, x;
    cout << "Generating p,q for L=" << L << " N=" << N << endl;
    q.set_prime(N);
    q.printhex("q=");
    // Note that p = x - (x mod (2*q)) + 1 => p is congruent to 1 mod 2q => 2q | (p-1)
    int i = 0;
    cout << "Generating p - this may take some time..." << endl;
    cout << "Iteration for p: " << endl;
    do {
        i++;
        cout << i << " ";
        x.set_rand_bits(L);
        p = x - (x % (q * 2)) + 1;
    } while (p.bitlen() != L || !p.is_prime());
    cout << endl << "Found suitable prime, p." << endl;
    p.printhex("p: ");
    cout << "Generate generator, g..." << endl;
    BigDigit e = (p - 1) / q;
    e.printhex("e=");
    i = 0;
    BigDigit h, g, chk;
    bool isprime;
    do {
        i++;
        cout << "Iteration for g " << i << endl;
        h.set_rand_number(p - 1);
        g = h.mod_exp(e, p);
    } while (g == 1);
    g.printhex("g=");

    // Check values for (p, q, g)...
    cout << "Final values:" << endl;
    p.printhex("p=");
    cout << "p has length " << p.bitlen() << " bits. About to check if p is prime (this takes time)..." << endl;
    isprime = p.is_prime();
    cout << "``p is prime`` is " << std::boolalpha << (isprime) << endl;
    assert(isprime);
    q.printhex("q=");
    cout << "q has length " << q.bitlen() << " bits." << endl;
    isprime = q.is_prime();
    cout << "``q is prime`` is " << std::boolalpha << (isprime) << endl;
    assert(isprime);
    chk = (p - 1) % q;
    cout << "(p-1) mod q=" << chk.to_str() << " (expecting 0)" << endl;
    assert(chk == 0);
    g.printhex("g=");
    cout << "``g < p`` is " << std::boolalpha << (g < p) << endl;
    assert(g < p);
    chk = g.mod_exp(q, p);
    cout << "g^q mod p=" << chk.to_str() << " (expecting 1)" << endl;
    assert(chk == 1);
    // Output (p, q, g)


}

/* Program entry point: run do_test() and report any escaping exception.
   Returns 0 on success and 1 if an exception was caught, so scripts and CI
   can detect a failed run. A failed assert() aborts the process before
   reaching the return statement. */
int main() {
    /* MSVC memory leak checking stuff */
#if _MSC_VER >= 1100
    _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDOUT);
    _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDOUT);
#endif

    int status = 0;

    /* Catch any exceptions: report them rather than terminating abruptly,
       but remember the failure so it is reflected in the exit status. */
    try {
        do_test();
    }
    catch (const std::exception& e) {
        // Handle standard exceptions with a specific message
        std::cerr << "Standard exception: " << e.what() << std::endl;
        status = 1;
    }
    catch (...) {
        std::cerr << "Caught unknown exception." << std::endl;
        status = 1;
    }

    cout << endl << "ALL DONE." << endl << endl;

    return status;
}