/* $Id: bdcombinatorics.c $ */

/*
* Copyright (C) 2026 David Ireland, D.I. Management Services Pty Limited
* <https://di-mgt.com.au/contact/> <https://di-mgt.com.au/bigdigits.html>
* SPDX-License-Identifier: MPL-2.0
*
* Last updated:
* $Date: 2026-03-30 11:23:00 $
* $Revision: 1.0.1 $
* $Author: dai $
*/

/* Provides the following functions:
int bd_factorial(BIGD result, int n);  // Computes n! (n must be >= 0)
int bd_binomial(BIGD result, int n, int k); // Computes n choose k (n >= k >= 0)
int bd_permutations(BIGD result, int n, int k); // Computes n permute k (n >= k >= 0)
*/

#include <stdio.h>
#include <time.h>
#include <assert.h>
#include "bigd.h"
#include "bdcombinatorics.h"

/* Debug tracing macros.
   Tracing is compiled in only when _DEBUG is defined AND NO_DPRINTF is NOT
   defined. Because NO_DPRINTF is defined on the next line, tracing is off by
   default even in _DEBUG builds. When tracing is off, both macros expand to
   nothing, so their arguments are not evaluated at all. */
#define NO_DPRINTF  /* <= Comment out to turn on debugging */
#if (defined(_DEBUG) && !(defined(NO_DPRINTF)))
/* printf() with a format string and exactly one argument */
#define DPRINTF1(s, a1) printf(s, a1)
/* Print BIGD b in decimal, preceded by string pre and followed by string suf */
#define BDPRDECIMAL(pre,b,suf) bdPrintDecimal((pre), (b), (suf))
#else
#define DPRINTF1(s, a1) 
#define BDPRDECIMAL(pre,b,suf)
#endif

/* Internal helper for bd_factorial(): set pr to the product of every integer j
   with b < j <= a, i.e. a * (a-1) * ... * (b+1). The empty product (a == b)
   gives 1. Requires a >= b.
   Ranges of up to three factors are multiplied out directly. Longer ranges are
   split at their midpoint and both halves are computed recursively, so the
   final multiplications are between operands of similar size.
   `depth` is the recursion depth and is used only for debug tracing.
   pr must be a different object from a and b, because the plain bdMultiply()
   (not bdMultiply_s) is used to write into it. */
static void product(BIGD pr, BIGD a, BIGD b, int depth)
{
    BIGD diff, mid, left, right;

    assert(bdCompare(a, b) >= 0);
    diff = bdNew();
    mid = bdNew();
    left = bdNew();
    right = bdNew();
    DPRINTF1("depth=%d\n", depth);

    bdSubtract(diff, a, b);     /* diff = a - b = number of factors in the range */
    BDPRDECIMAL("a=", a, ", ");
    BDPRDECIMAL("b=", b, ", ");
    BDPRDECIMAL("d=", diff, "\n");

    if (bdShortIsEqual(diff, 0)) {
        /* Empty range */
        bdSetShort(pr, 1);
    } else if (bdShortIsEqual(diff, 1)) {
        /* Single factor: a */
        bdSetEqual(pr, a);
    } else if (bdShortIsEqual(diff, 2)) {
        /* a * (a - 1) */
        bdShortSub(left, a, 1);
        bdMultiply(pr, a, left);
    } else if (bdShortIsEqual(diff, 3)) {
        /* (a - 2) * (a * (a - 1)) */
        bdShortSub(left, a, 1);
        bdMultiply(right, a, left);
        bdShortSub(left, a, 2);
        bdMultiply(pr, left, right);
    } else {
        /* Split (b, a] at mid = floor((a + b) / 2) into (mid, a] and (b, mid] */
        bdAdd(left, a, b);
        bdShortIntDiv(mid, left, 2);
        BDPRDECIMAL("m=", mid, "\n");
        product(left, a, mid, depth + 1);
        BDPRDECIMAL("product-sub1=", left, "\n");
        product(right, mid, b, depth + 1);
        BDPRDECIMAL("product-sub2=", right, "\n");
        bdMultiply(pr, left, right);
        BDPRDECIMAL("product returns ", pr, "\n");
    }

    bdFree(&diff);
    bdFree(&mid);
    bdFree(&left);
    bdFree(&right);
}

/** Compute factorial(n) = n!
Evaluated as the product of the integers in the range (0, n] using the
divide-and-conquer helper product(). By that convention 0! = 1.
@param[out] result To receive result or zero on error
@param[in] n Input integer >= 0
@return 0 on success or nonzero if error (n < 0).
@code
BIGD result;
result = bdNew();
bd_factorial(result, 52);
bdPrintDecimal("", result, "\n");
// 80658175170943878571660636856403766975289505440883277824000000000000
bdFree(&result);
@endcode
*/
int bd_factorial(BIGD result, int n)
{
    BIGD top, bottom;

    if (n >= 0) {
        top = bdNew();
        bottom = bdNew();
        bdSetShort(top, n);
        bdSetZero(bottom);
        /* n! = n * (n-1) * ... * 1 = product over (0, n] */
        product(result, top, bottom, 1);
        bdFree(&bottom);
        bdFree(&top);
        return 0;
    }

    /* Negative input: report error with a zero result */
    bdSetZero(result);
    return 1;
}

/** Compute binomial(n, k) = nCk ``n choose k``
Uses the multiplicative recurrence C(n, i+1) = C(n, i) * (n - i) / (i + 1).
Every intermediate value is therefore itself a binomial coefficient, and each
division is exact. k is first replaced by min(k, n - k) to minimise the number
of iterations.
@param[out] result To receive result or zero on error
@param[in] n Input integer >= 0
@param[in] k Input integer n >= k >= 0
@return 0 on success or nonzero if error.
@code
BIGD result;
result = bdNew();
bd_binomial(result, 52,13);
bdPrintDecimal("", result, "\n");
// 635013559600
bdFree(&result);
@endcode
*/
int bd_binomial(BIGD result, int n, int k)
{
    int i;
    BIGD u, N;

    /* Rejects k < 0 and k > n. Since k >= 0 is then required, this also rejects n < 0. */
    if (k > n || k < 0) {
        bdSetZero(result);
        return 1;
    }

    u = bdNew();
    N = bdNew();
    if (k > n - k) {
        k = n - k;    // Symmetry: C(n, k) == C(n, n-k); use the shorter loop
    }
    // Hold n as a BIGD so that n - i can be formed with bdShortSub()
    bdSetShort(N, n);
    BDPRDECIMAL("N=", N, "\n");
    DPRINTF1("k=%d\n", k);

    // result = 1 = C(n, 0)
    bdSetShort(result, 1);
    BDPRDECIMAL("result=", result, "\n");
    // for i in range(k):
    for (i = 0; i < k; i++) {
        //  result = result * (n - i) // (i + 1)
        DPRINTF1("i=%d\n", i);
        bdShortSub(u, N, i);  // u = n - i
        BDPRDECIMAL("u=n-i=", u, "\n");
        // bdMultiply_s is the overlap-safe variant; result is both input and output
        bdMultiply_s(result, u, result);  // result = result * u = result * (n-i)
        BDPRDECIMAL("result=result*(n-i)=", result, "\n");
        // Exact division: result now equals C(n, i) * (n - i) == C(n, i+1) * (i + 1)
        bdShortIntDiv(result, result, i + 1);    // result = result//(i+1)
        BDPRDECIMAL("result=", result, "\n");
    }
    // result now holds C(n, k)

    bdFree(&u);
    bdFree(&N);
    return 0; // OK
}

/** Compute permutations(n, k) = nPk ``n permute k``
This is the falling product n * (n-1) * ... * (n-k+1), which equals n! / (n-k)!.
By that convention nP0 = 1.
@param[out] result To receive result or zero on error
@param[in] n Input integer >= 0
@param[in] k Input integer n >= k >= 0
@return 0 on success or nonzero if error.
@code
BIGD result;
result = bdNew();
bd_permutations(result, 52,13);
bdPrintDecimal("", result, "\n");
// 3954242643911239680000
bdFree(&result);
@endcode
*/
int bd_permutations(BIGD result, int n, int k)
{
    BIGD factor, N;
    int j;

    /* Valid only for 0 <= k <= n (which also forces n >= 0) */
    if (k < 0 || n < k) {
        bdSetZero(result);
        return 1;
    }

    factor = bdNew();
    N = bdNew();

    /* Keep n as a BIGD so each factor n - j comes from bdShortSub() */
    bdSetShort(N, n);
    BDPRDECIMAL("N=", N, "\n");
    DPRINTF1("k=%d\n", k);

    /* Start from the empty product and multiply in n, n-1, ..., n-k+1 */
    bdSetShort(result, 1);
    BDPRDECIMAL("result=", result, "\n");
    j = 0;
    while (j < k) {
        DPRINTF1("i=%d\n", j);
        bdShortSub(factor, N, j);
        BDPRDECIMAL("t=n-i=", factor, "\n");
        /* Overlap-safe multiply: result is both an input and the output */
        bdMultiply_s(result, factor, result);
        BDPRDECIMAL("result=result*(n-i)=", result, "\n");
        j++;
    }

    bdFree(&factor);
    bdFree(&N);
    return 0;
}