# $Id: discrete_gaussian_zq.py $
# $Date: 2024-02-12 13:57Z $

import math
import random

"""A discrete Gaussian sampler over the integers mod q.

Usage::

    from discrete_gaussian_zq import Dgi
    dgi = Dgi(q)
    e = [dgi.D() for x in range(N)]

This implementation supports only the centre ``c = 0``.

Equivalent to Sage [Ref 2, p9]::

   import sage.stats.distributions.discrete_gaussian_integer as dgi
   def sample_noise(N, R):
       D = dgi.DiscreteGaussianDistributionIntegerSampler(sigma=1.0)
       return vector([R(D()) for i in range(N)])
   R = Integers(q)
   e = sample_noise(N, R)

Refs:
[1] Sage 10.0 Reference manual, "Discrete Gaussian Samplers over the Integers"
https://doc.sagemath.org/html/en/reference/stats/sage/stats/distributions/discrete_gaussian_integer.html

[2] Yang Li and Kee Siong Ng and Michael Purcell,
"A Tutorial Introduction to Lattice-based Cryptography and Homomorphic Encryption",
2022, arXiv:2208.08125 [cs.CR], https://doi.org/10.48550/arXiv.2208.08125
"""

# ****************************** LICENSE ***********************************
# Copyright (C) 2023-24 David Ireland, DI Management Services Pty Limited.
# All rights reserved. <www.di-mgt.com.au> <www.cryptosys.net>
# The code in this module is licensed under the terms of the MIT license.
# @license MIT
# For a copy, see <http://opensource.org/licenses/MIT>
# **************************************************************************

# Version string of this module.
__version__ = "1.0.0"


class Dgi():
    """Discrete Gaussian sampler over the integers mod q.

    Draws an integer ``x`` from a discrete Gaussian centred on 0 with
    standard deviation ``sigma``, truncated to ``[-bound, bound]`` where
    ``bound = floor(tau * sigma)``, and returns ``x mod q``.
    """

    def __init__(self, q, sigma=1.0, tau=6.0, rng=None):
        """Construct a new sampler for a discrete Gaussian distribution.

        Args:
            q (int): modulus, an integer greater than 1
            sigma (float): standard deviation, must be positive
            tau (float): tail-cut factor, must be non-negative; samples
                outside the range [-tau * sigma, tau * sigma] are
                considered to have probability zero.
            rng: optional source of randomness providing ``randint(a, b)``
                and ``random()``, e.g. ``random.Random(seed)`` for
                reproducible output or ``random.SystemRandom()`` for an
                OS-backed source. Defaults to the global ``random`` module.

        Raises:
            ValueError: if ``q < 2``, ``sigma <= 0`` or ``tau < 0``.
        """
        if q < 2:
            raise ValueError("q must be an integer greater than 1")
        # Written as "not ... > 0" so that NaN is rejected as well.
        if not sigma > 0:
            raise ValueError("sigma must be positive")
        if not tau >= 0:
            raise ValueError("tau must be non-negative")
        # Centre c is always 0.
        self.q = q
        self.sigma = sigma
        # Source of randomness; the random module itself exposes
        # randint() and random(), so it serves as the default.
        self.rng = random if rng is None else rng
        # Normalising factor of the continuous Gaussian density. It cancels
        # out in rejection sampling but keeps f() a true density.
        self.scale = 1 / (sigma * math.sqrt(2 * math.pi))
        # Integer tail cut: x is drawn from [-bound, bound] inclusive.
        self.bound = math.floor(tau * sigma)
        # Peak of f (at the centre); upper limit for the uniform y draw.
        self.fmax = self.f(0)
        # Table of f(x) for every integer x in [-bound, bound] (inclusive),
        # indexed by x + bound. NB bound+1 for range() upper limit.
        self.tab = [self.f(x)
                    for x in range(-self.bound, self.bound + 1)]

    def f(self, x):
        """Gaussian probability density function, ``f(x)``, centred on 0."""
        return self.scale * math.exp(-x*x/(2*self.sigma*self.sigma))

    def D(self):
        """Return a sample in the range [0, q-1].

        Uses rejection sampling over the integer range [-bound, bound]::

            do {
               select integer x uniformly in [-bound, bound]
               select y uniformly in [0.0, fmax)
            } while y > f(x)
            return x mod q
        """
        bound = self.bound
        while True:
            # randint(a, b) is inclusive of both end points.
            x = self.rng.randint(-bound, bound)
            y = self.rng.random() * self.fmax
            # Accept x with probability f(x)/fmax; tab[i] caches f(i - bound).
            if y <= self.tab[x + bound]:
                # Python's % yields a value in [0, q-1] for q > 0 even when
                # |x| >= q, which happens whenever bound >= q.
                return x % self.q


def _demo(label, sampler, count):
    """Draw ``count`` samples from ``sampler`` and print them three ways."""
    from collections import Counter

    samples = [sampler.D() for _ in range(count)]
    print(label + " =", samples)
    print(label + " =", sorted(samples))
    # Counter tallies in one pass; sort by value for a stable display order.
    print("value:count:", dict(sorted(Counter(samples).items())))


def main():
    """Print demo draws with q=31 for sigma=1.0 and sigma=3.0."""
    # Tests
    _demo("e1", Dgi(31), 40)
    _demo("e3", Dgi(31, 3.0), 60)


# Run the demonstration only when executed as a script, not on import.
if __name__ == "__main__":
    main()