# $Id: lattice-lwe-simple-fixed.py $
# $Date: 2024-02-24 12:57Z $

"""
Toy LWE (Learning With Errors) public-key encryption of a single bit.

The same as ``lattice-lwe-simple.py`` but with the random values prepared
earlier and fixed for this example, so every run prints the same output.
"""

import matrixzq as mq

# ****************************** LICENSE ***********************************
# Copyright (C) 2024 David Ireland, DI Management Services Pty Limited.
# All rights reserved. <www.di-mgt.com.au> <www.cryptosys.net>
# The code in this module is licensed under the terms of the MIT license.
# @license MIT
# For a copy, see <http://opensource.org/licenses/MIT>
# **************************************************************************

# Debugging switches
DEBUG = False  # Set to True to show debugging output


def _noop(*args, **kwargs):
    """Accept any arguments and do nothing; stands in for print() when DEBUG is off."""
    return None


# DPRINT behaves like print() only when DEBUG was True at import time.
DPRINT = print if DEBUG else _noop

# Parameters of this toy LWE instance:
#   q -- modulus used for all Z_q arithmetic
#   n -- security parameter (length of the secret vector s, number of columns of A)
#   N -- sample size (number of rows of A; length of the vectors e and r)
# \chi is the noise distribution over Z_q. In this fixed example it appears only
# through the hard-coded noise vector e below, whose entries are small.
q, n, N = 31, 4, 7

# Tell matrixzq which modulus to use for its operations (presumably module-level state).
mq.set_modulus(q)

# GenKey:
# -------------------------------------------
# Secret key: a vector s in Z_q^n. The "random" sample s <-- Z_q^n was drawn
# in advance and is hard-coded here so the example is reproducible.
s = mq.new_vector([23, 8, 25, 6])
print("s =", mq.sprint_vector(s))

# Public key, first part: an N x n matrix A <-- Z_q^{N \times n} (fixed sample).
A = mq.new_matrix([
    [23, 30, 26, 21],
    [17, 9, 24, 13],
    [11, 19, 26, 19],
    [22, 7, 29, 6],
    [21, 21, 13, 14],
    [14, 28, 20, 26],
    [16, 26, 21, 20],
])
print(f"A:\n{mq.sprint_matrix(A)}")
if DEBUG:
    mq.print_matrix_latex(A)

# Small noise vector e <-- \chi^N (fixed sample; the entry 30 is -1 mod q).
e = mq.new_vector([1, 0, 0, 0, 1, 30, 1])
print("e =", mq.sprint_vector(e))
if DEBUG:
    mq.print_matrix_latex(e)

# Public key, second part: b = A*s + e
b = mq.add(mq.multiply(A, s), e)
print("b =", mq.sprint_vector(b))
if DEBUG:
    mq.print_matrix_latex(b)

# Enc: c <-- Enc(pk, m)
# ----------------------
print("Enc:")
# Encrypt one message bit m in {0,1} with the public key (A, b):
#   r <-- {0,1}^N   (a random vector in Z_2^N)
#   u  = A^T * r
#   v  = (b . r + floor(q/2) * m) mod q
# The ciphertext is the pair c = (u, v).
m = 1
print("m =", m)

# Random 0/1 vector r of length N (a fixed sample for this example)
r = mq.new_vector([0, 1, 0, 1, 1, 0, 0])
print("r =", mq.sprint_vector(r))
if DEBUG:
    mq.print_matrix_latex(r)

# u = A^T * r
if DEBUG:
    mq.print_matrix_latex(mq.transpose(A))
u = mq.multiply(mq.transpose(A), r)
print("u=A^T*r =", mq.sprint_vector(u))

# Encode the bit: m = 0 contributes 0, m = 1 contributes floor(q/2).
qm2 = (q // 2) * m
print("floor(q/2)*m =", qm2)

b_dot_r = mq.dotproduct(b, r)
print("b dot r =", b_dot_r)
v = (b_dot_r + qm2) % q
print("v =", v)
print(f"c = (u,v) =({mq.sprint_vector(u)}, {v})")

# Dec: m := Dec(sk, c)
# -----------------------
print("Dec:")
# Decrypt the ciphertext c = (u, v) with the secret key s:
#   v' = s^T * u           (= u . s = r^T A s)
#   d  = (v - v') mod q     (= e . r + floor(q/2)*m mod q: small noise plus the encoded bit)
#   m  = roundint(2d/q) mod 2
# The final "mod 2" maps a d close to q (i.e. slightly negative noise), which
# rounds 2d/q up to 2, back to bit 0.
v1 = mq.dotproduct(u, s)
print("v' = u dot s =", v1)
d = (v - v1) % q  # with q > 0, Python's % already yields 0 <= d < q
print("d = v - v' =", d)

# Compute the rounding once and reuse it for both the trace line and the result.
# NOTE(review): assumes roundfrac2int(a, b) rounds a/b to the nearest integer -- confirm in matrixzq.
rounded = mq.roundfrac2int(2 * d, q)
print(f"[<2d/q>] = [<2*{d}/{q}>] = {rounded}")

decrypted_message = rounded % 2
print("decrypted message =", decrypted_message)
print("original message  =", m)
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if decrypted_message != m:
    raise AssertionError(
        f"decryption failed: got {decrypted_message}, expected {m}")


print("\nALL DONE.")