/* $Id: t_dibigdDHKeyExch.cpp $ */

/*
* Copyright (C) 2026 David Ireland, D.I. Management Services Pty Limited
* <https://di-mgt.com.au/contact/> <https://di-mgt.com.au/bigdigits.html>
* SPDX-License-Identifier: MPL-2.0
*
* Last updated:
* $Date: 2026-05-05 02:02 $
* $Revision: 1.0.1 $
* $Author: dai $
*/

/* Perform Diffie-Hellman key exchange.
Parties A and B agree on the same DH domain parameters (p,q,g).
A chooses random private key x in [1,q-2] and computes public key y=g^x mod p.
A sends y to party B.
B chooses random private key u in [1,q-2] and computes public key v=g^u mod p.
B sends v to party A.
A receives v and outputs common secret z = v^x mod p
B receives y and outputs common secret z = y^u mod p
Common secret is the same because z = v^x = (g^u)^x = (g^x)^u = y^u = z (mod p).
*/

#ifdef NDEBUG
#undef NDEBUG
#endif
#include <iostream>
#include <string>
#include <iomanip>
#include <stdexcept>
#include <vector>
#include <assert.h>
#include <time.h>
#include "dibigd.hpp"
using std::cout;
using std::endl;
using namespace dibigd;

/**
 * Format a boolean value as the text "true" or "false".
 * Returns a non-const std::string so the result can be moved from.
 * @param b  value to format
 * @return "true" if b is true, otherwise "false"
 */
inline std::string bool2str(bool b) {
    return b ? "true" : "false";
}

/*
 * Check a set of DH domain parameters (p,q,g) before use.
 * You'd do this once, the first time you received the parameters, then you'd trust them.
 *
 * Checks, in order (progress and intermediate results are printed to stdout):
 *   - p has at least L bits and is prime;
 *   - q has at least N bits and is prime;
 *   - q divides p - 1;
 *   - 1 < g < p;
 *   - g^q mod p == 1, so (q being prime and g != 1) g generates the order-q subgroup.
 *
 * @param p  candidate prime modulus
 * @param q  candidate prime subgroup order
 * @param g  candidate generator
 * @param L  minimum acceptable bit length of p
 * @param N  minimum acceptable bit length of q
 * @return true if every check passes; false at the first check that fails
 */
bool validate_parameters(BigDigit p, BigDigit q, BigDigit g, size_t L, size_t N) {
    cout << "Validating parameters..." << endl;

    cout << "bitlen(p)=" << p.bitlen() << endl;
    if (p.bitlen() < L)
        return false;
    cout << "Checking that p is prime (takes time)..." << endl;
    const bool p_is_prime = p.is_prime();
    cout << "``p is prime`` is " << bool2str(p_is_prime) << endl;
    if (!p_is_prime)
        return false;

    cout << "bitlen(q)=" << q.bitlen() << endl;
    if (q.bitlen() < N)
        return false;
    const bool q_is_prime = q.is_prime();
    cout << "``q is prime`` is " << bool2str(q_is_prime) << endl;
    if (!q_is_prime)
        return false;

    BigDigit chk = (p - 1) % q;
    cout << "(p - 1) mod q = " << chk.to_str() << " (expecting 0)" << endl;
    if (chk != 0)
        return false;

    // The lower bound matters: g = 1 would otherwise pass the g^q mod p == 1 test
    // below while generating only the trivial subgroup {1} (FIPS 186-4 A.2.2).
    const bool g_in_range = g > 1 && g < p;
    cout << "``1 < g < p`` is " << bool2str(g_in_range) << endl;
    if (!g_in_range)
        return false;

    chk = g.mod_exp(q, p);
    cout << "g^q mod p = " << chk.to_str() << " (expecting 1)" << endl;
    if (chk != 1)
        return false;

    // OK, we have succeeded
    return true;
}

void do_test() {
    cout << "Perform Diffie-Hellman key exchange." << endl;
    cout << "====================================" << endl;
    // validate_parameters takes some time, you may omit
    bool do_validate = false;
    size_t L = 1024;
    size_t N = 160;
    // Input parameters (p,q,g) = Some values we made earlier
    BigDigit p("0xf9f0867396bc9756f43d799dc8432bbb23a8bfdddbf73e4d43514de82c2770bf054042034af42d84f8c09b2a71899ae700f9a5f249409"
        "9eed4777e3ab4a16c69fd6bcf9c62805366f0a6a78aa93e772f86ee4f8b96807ea7397a7d9a608cf07fc675f401cee4e0ba01507c2f4a54cdb75d0"
        "9e3ccec093cc956cf907e78f06fd5");
    BigDigit q("0xc4d6773eb9fd7828b4c329f28d3637e5c592d5d1");
    BigDigit g("0x4e6397c931e3a7c84014ee0e168d396a680f0e687175353e6d799c63ab6213f90be18c33245c7161ae25977a03713dcc678b3ad3a0fde"
        "3a3f259bd64a1616d2e8c3fbbac0a58c0ad585bd4485261e68d31ee9b343ffb5ad90b3854e39d92b33f0ccb1a17b63cb6085d931970e7b02492924"
        "676f8d3655797580e6e2c781fd710");
    p.printhex("p="); q.printhex("q="); g.printhex("g=");

    // Optional check of parameters (only need to do once in practice)
    if (do_validate) {
        assert(validate_parameters(p, q, g, L, N));
    }

    cout << "A and B generate their private/public key pairs..." << endl;

    // Party A generates a private/public key pair (x,y)
    BigDigit x, y;
    x.set_rand_number(q - 1); // x in [2,(q-2)]
    assert(x > 1);    // chances of failure?
    x.printhex("A's private key x_A is ");
    y = g.mod_exp(x, p);
    y.printhex("A's public key y_A is ");

    // Party B generates a private/public key pair (u,v)
    BigDigit u, v;
    u.set_rand_number(q - 1); // u in [2,(q-2)]
    assert(u > 1);
    u.printhex("B's private key x_B is ");
    v = g.mod_exp(u, p);
    v.printhex("B's public key y_B is ");

    bool ok;
    BigDigit z;
    // Party B validates A's public key, y_A, and computes the common secret z
    ok = y > 1 && y < p;
    cout << "``1 < y_A <p`` is " << bool2str(ok) << endl;
    assert(ok);
    z = y.mod_exp(u, p);
    cout << "Common secret computed by B, Z_B=" << z.to_strhex() << endl;

    BigDigit o = z; // Remember for later

    // Party A validates B's public key, y_b, and computes the common secret z
    ok = v > 1 && v < p;
    cout << "``1 < y_B <p`` is " << bool2str(ok) << endl;
    assert(ok);
    z = v.mod_exp(x, p);
    cout << "Common secret computed by A, Z_A=" << z.to_strhex() << endl;

    assert(z == o);

    // Serialize z to an octet string
    size_t zbytes = (L + 7) / 8;
    std::vector<unsigned char> ZZ = z.to_octets(zbytes);
    Bvec::print_hex(ZZ, ":");
}

/*
 * Run the DH key-exchange demonstration.
 * @return 0 on success; 1 if do_test() threw an exception, so scripts and CI
 *         can detect the failure.
 */
int main() {
    /* MSVC memory leak checking stuff */
#if _MSC_VER >= 1100
    _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
    _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDOUT);
    _CrtSetReportMode(_CRT_ERROR, _CRTDBG_MODE_FILE);
    _CrtSetReportFile(_CRT_ERROR, _CRTDBG_FILE_STDOUT);
#endif

    int status = 0;

    /* Catch any exceptions, report them, and record the failure in the exit status */
    try {
        do_test();
    }
    catch (const std::exception& e) {
        // Handle standard exceptions with a specific message
        std::cerr << "Standard exception: " << e.what() << std::endl;
        status = 1;
    }
    catch (...) {
        std::cerr << "Caught unknown exception." << std::endl;
        status = 1;
    }

    cout << endl << "ALL DONE." << endl << endl;

    return status;
}