# $Id: lattice-lwe-simple.py $
# $Date: 2024-02-24 12:47Z $

import matrixzq as mq
import discrete_gaussian_zq as dgz
import random

"""
A simple public-key encryption scheme based on the learning
with errors (LWE) problem [Regev, 2005].
"""

# ****************************** LICENSE ***********************************
# Copyright (C) 2024 David Ireland, DI Management Services Pty Limited.
# All rights reserved. <www.di-mgt.com.au> <www.cryptosys.net>
# The code in this module is licensed under the terms of the MIT license.
# @license MIT
# For a copy, see <http://opensource.org/licenses/MIT>
# **************************************************************************

# Debugging stuff: DPRINT behaves exactly like print() while DEBUG is on and
# silently discards its arguments otherwise, so diagnostic calls can stay put.
DEBUG = True  # True shows the intermediate values; False silences them
if DEBUG:
    DPRINT = print
else:
    def DPRINT(*args, **kwargs):
        """Accept print()-style arguments and discard them."""
        return None


def _lwe_gen(q, n, N, sigma):
    """Run Gen: sample a key pair and return ``(A, b, s)``.

    The public key is ``(A, b)`` with ``b = A*s + e`` over Z_q, where ``A``
    is a random N x n matrix of ``mq.random_element()`` entries and ``e`` is
    a length-N noise vector drawn from ``dgz.Dgi(q, sigma=sigma)``. The
    private key is ``s``. Assumes ``mq.set_modulus(q)`` was already called.
    """
    DPRINT("Gen:")
    # Private key: sample a private key s <-- Z_q^n
    s = mq.new_vector([mq.random_element() for _ in range(n)])
    DPRINT("s =", mq.sprint_vector(s))

    # Public key: sample a random matrix A <-- Z_q^{N x n}
    A = mq.new_matrix([[mq.random_element() for _ in range(n)]
                       for _ in range(N)])
    DPRINT(f"A:\n{mq.sprint_matrix(A)}")
    # Random noise vector e <-- chi^N
    dgi = dgz.Dgi(q, sigma=sigma)
    e = mq.new_vector([dgi.D() for _ in range(N)])
    DPRINT("e =", mq.sprint_vector(e))
    # Compute b = As + e
    b = mq.add(mq.multiply(A, s), e)
    DPRINT("b =", mq.sprint_vector(b))
    return A, b, s


def _lwe_enc(q, N, A, b, m):
    """Run Enc: encrypt the bit ``m`` under the public key ``(A, b)``.

    Picks a random r <-- {0,1}^N and returns the ciphertext ``(u, v)`` with
    u = A^T*r and v = b.r + floor(q/2)*m (mod q).
    """
    DPRINT("Enc:")
    DPRINT("m =", m)

    # Random vector r of length N from {0,1}
    r = mq.new_vector([random.randint(0, 1) for _ in range(N)])
    DPRINT("r =", mq.sprint_vector(r))
    # u = A^T*r
    u = mq.multiply(mq.transpose(A), r)
    DPRINT("u=A^T*r =", mq.sprint_vector(u))
    qm2 = (q // 2) * m
    DPRINT("floor(q/2)*m =", qm2)
    # Reduce mod q so v is a proper element of Z_q: the sum b.r + floor(q/2)*m
    # is not otherwise guaranteed to be below q.
    v = (mq.dotproduct(b, r) + qm2) % q
    DPRINT("v =", v)
    DPRINT(f"c = (u,v) =({mq.sprint_vector(u)}, {v})")
    return u, v


def _lwe_dec(q, s, u, v):
    """Run Dec: recover the bit encrypted in the ciphertext ``(u, v)``.

    Computes v' = u.s and d = v - v' (mod q), then returns round(2d/q) mod 2.
    Since d = e.r + floor(q/2)*m (mod q), d lies near 0 (or q) for m = 0 and
    near q/2 for m = 1, so this recovers m as long as the accumulated noise
    e.r is small (roughly below q/4 in magnitude). Assumes
    ``mq.roundfrac2int(x, y)`` rounds x/y to the nearest integer.
    """
    DPRINT("Dec:")
    v1 = mq.dotproduct(u, s)
    DPRINT("v' = u dot s =", v1)
    # Reduce d into [0, q) so the rounding step never sees a negative
    # numerator; the result mod 2 is unaffected by adding multiples of q.
    d = (v - v1) % q
    DPRINT("d = v - v' =", d)
    return mq.roundfrac2int(2 * d, q) % 2


def do_pke(q, n, N, sigma=1.0):
    """Perform the PKE operations Gen/Enc/Dec on a random one-bit message.

    Generates a key pair, encrypts a random bit m, decrypts the ciphertext
    and checks that the round trip recovers m. Intermediate values are
    printed via DPRINT.

    Args:
        q (int): Modulus
        n (int): Security parameter (length of the secret vector s)
        N (int): Sample size (number of LWE samples, i.e. rows of A)
        sigma (float): Width parameter passed to the noise sampler
            ``dgz.Dgi`` that produces the error vector e.

    Raises:
        AssertionError: if the decrypted bit differs from the original
            message (possible when the noise is too large relative to q).
    """
    print(f"q={q} n={n} N={N} sigma={sigma}")
    mq.set_modulus(q)

    # Gen: (pk, sk) <-- Gen(1^n), with pk = (A, b) and sk = s
    A, b, s = _lwe_gen(q, n, N, sigma)

    # Enc: c <-- Enc(pk, m) for a random message bit m in {0,1}
    m = random.randint(0, 1)
    u, v = _lwe_enc(q, N, A, b, m)

    # Dec: m' := Dec(sk, c)
    decrypted_message = _lwe_dec(q, s, u, v)
    print("decrypted message =", decrypted_message)
    print("original message  =", m)
    if decrypted_message != m:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise AssertionError(
            f"decryption failed: got {decrypted_message}, expected {m}")


def _version_tuple(version):
    """Return the numeric components of a dotted version string as a tuple.

    Each dot-separated component contributes its leading digits (so
    ``"1rc2"`` counts as 1). Parsing stops at the first component without
    leading digits. Trailing zero components are dropped so that ``"1.1"``
    and ``"1.1.0"`` compare equal. The result is meant for ``<``/``>=``
    comparisons, which then order versions numerically rather than as text.
    """
    import re

    parts = []
    for piece in str(version).split("."):
        match = re.match(r"[0-9]+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    while parts and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def main():
    """Run the LWE PKE demo with small illustrative parameters."""
    # Parameters (deliberately small for illustrative purposes)
    q = 31  # Modulus
    n = 4   # Security parameter
    N = 7   # LWE sample size
    # The noise distribution chi over Z_q uses do_pke's default sigma.

    MQ_MIN_VER = '1.1.0'
    # Compare versions numerically: a plain string comparison would rank
    # "1.9.0" above "1.10.0". A module without __version__ counts as too old.
    mq_version = getattr(mq, "__version__", "0")
    if _version_tuple(mq_version) < _version_tuple(MQ_MIN_VER):
        raise RuntimeError("Require at least matrixzq version " + MQ_MIN_VER
                           + f" (found {mq_version})")
    do_pke(q, n, N)

    # An alternative, much larger parameter set (not run by default):
    #     do_pke(q=655360001, n=1000, N=500)

    print("ALL DONE.")


if __name__ == "__main__":
    main()