// rabin.cpp - written and placed in the public domain by Wei Dai

#include "misc.h"
#include "asn.h"
#include "nbtheory.h"

#include "rabin.h"

// Build a public key directly from its components: the modulus n and the
// two multipliers r and s that RawEncrypt folds into ciphertexts.
// modulusLen caches the modulus size in bytes.
RabinPublicKey::RabinPublicKey(const Integer &n, const Integer &r, const Integer &s)
	: n(n),
	  r(r),
	  s(s),
	  modulusLen(n.ByteCount())	// parameter n; the member is identical
{
}

// Decode a public key from a BER-encoded sequence laid out as { n, r, s },
// i.e. the same field order DEREncode writes.
RabinPublicKey::RabinPublicKey(BufferedTransformation &bt)
{
	BERSequenceDecoder decoder(bt);

	n.BERDecode(decoder);
	r.BERDecode(decoder);
	s.BERDecode(decoder);

	// Cache the modulus length now that n is known.
	modulusLen = n.ByteCount();
}

// Serialize the public key as a DER sequence { n, r, s }. The field order
// must match the BER-decoding constructor.
void RabinPublicKey::DEREncode(BufferedTransformation &bt) const
{
	DERSequenceEncoder sequence(bt);

	n.DEREncode(sequence);
	r.DEREncode(sequence);
	s.DEREncode(sequence);
}

void RabinPublicKey::RawEncrypt(const Integer &in, Integer &out) const
{
	out = in.Square()%n;
	if (in[0])
		out = out*r%n;
	if (Jacobi(in, n)==-1)
		out = out*s%n;
}

// *****************************************************************************
// private key operations:

// Build a private key from its components. The public part (n, r, s) goes to
// the base class. The assertions check the invariants RawEncrypt/RawDecrypt
// depend on:
//   n = p*q;
//   r has Jacobi symbol +1 mod p and -1 mod q;
//   s has Jacobi symbol -1 mod p and +1 mod q;
//   u is the inverse of q modulo p.
RabinPrivateKey::RabinPrivateKey(const Integer &n, const Integer &r, const Integer &s,
                                 const Integer &p, const Integer &q, const Integer &u)
	: RabinPublicKey(n, r, s),
	  p(p),
	  q(q),
	  u(u)
{
	assert(p*q==n);

	// signature of r over (p, q): (+1, -1)
	assert(Jacobi(r, p) == 1);
	assert(Jacobi(r, q) == -1);

	// signature of s over (p, q): (-1, +1)
	assert(Jacobi(s, p) == -1);
	assert(Jacobi(s, q) == 1);

	assert(u*q%p==1);
}

// generate a random private key
// Generate a random private key whose modulus n = p*q has exactly 'keybits'
// bits.
//
// p and q are drawn with Integer::BLUMINT (presumably primes = 3 mod 4, which
// RawDecrypt's square-root/sign logic relies on; confirm against integer.h).
// They are required to be distinct.
//
// r and s are the smallest integers >= 2 whose Jacobi symbols over (p, q)
// are (+1, -1) and (-1, +1) respectively.
//
// u = q^-1 mod p is stored for CRT recombination in RawDecrypt.
RabinPrivateKey::RabinPrivateKey(RandomNumberGenerator &rng, unsigned int keybits)
{
	assert(keybits >= 16);

	// Pick a range for the primes so that any product of two of them has
	// exactly 'keybits' bits.
	//  - even k: [182/256 * 2^(k/2), 2^(k/2) - 1]. Since 182/256 > 1/sqrt(2),
	//    p*q >= 2^(k-1).
	//  - odd k:  [2^((k-1)/2), 181/256 * 2^((k+1)/2)]. Since 181/256 < 1/sqrt(2),
	//    p*q < 2^k.
	// keybits >= 16 keeps both shift counts non-negative.
	Integer minP, maxP;
	if (keybits%2==0)
	{
		minP = Integer(182) << (keybits/2-8);
		maxP = Integer::Power2(keybits/2)-1;
	}
	else
	{
		minP = Integer::Power2((keybits-1)/2);
		maxP = Integer(181) << ((keybits+1)/2-8);
	}

	p.Randomize(rng, minP, maxP, Integer::BLUMINT);
	// p and q must differ. If p == q, then Jacobi(t, p) == Jacobi(t, q) for
	// every t, so the search for r and s below would never terminate (and n
	// would be a perfect square). The prime range is small for short keys,
	// so a collision is a realistic event there.
	do
		q.Randomize(rng, minP, maxP, Integer::BLUMINT);
	while (q == p);

	// Find the smallest r with Jacobi signature (+1, -1) and the smallest s
	// with signature (-1, +1) over (p, q).
	boolean rFound=FALSE, sFound=FALSE;
	Integer t=2;
	while (!(rFound && sFound))
	{
		int jp = Jacobi(t, p);
		int jq = Jacobi(t, q);

		if (!rFound && jp==1 && jq==-1)
		{
			r = t;
			rFound = TRUE;
		}

		if (!sFound && jp==-1 && jq==1)
		{
			s = t;
			sFound = TRUE;
		}

		++t;
	}

	n = p * q;
	assert(n.BitCount() == keybits);
	modulusLen = n.ByteCount();

	// u = q^-1 mod p, the CRT coefficient used by RawDecrypt.
	u = EuclideanMultiplicativeInverse(q, p);
	assert(u*q%p==1);
}

// Decode a private key from a BER-encoded sequence laid out as
// { n, r, s, p, q, u }, i.e. the same field order DEREncode writes.
RabinPrivateKey::RabinPrivateKey(BufferedTransformation &bt)
{
	BERSequenceDecoder decoder(bt);

	// public components
	n.BERDecode(decoder);
	r.BERDecode(decoder);
	s.BERDecode(decoder);

	// private components
	p.BERDecode(decoder);
	q.BERDecode(decoder);
	u.BERDecode(decoder);

	// Cache the modulus length now that n is known.
	modulusLen = n.ByteCount();
}

// Serialize the private key as a DER sequence { n, r, s, p, q, u }. The field
// order must match the BER-decoding constructor.
void RabinPrivateKey::DEREncode(BufferedTransformation &bt) const
{
	DERSequenceEncoder sequence(bt);

	// public components
	n.DEREncode(sequence);
	r.DEREncode(sequence);
	s.DEREncode(sequence);

	// private components
	p.DEREncode(sequence);
	q.DEREncode(sequence);
	u.DEREncode(sequence);
}

// Raw Rabin decryption: the inverse of RabinPublicKey::RawEncrypt.
//
// RawEncrypt produces c = m^2 * r^a * s^b mod n, with a = (m is odd) and
// b = (Jacobi(m, n) == -1). The key invariants give Jacobi(r,p)=1,
// Jacobi(r,q)=-1, Jacobi(s,p)=-1 and Jacobi(s,q)=1. Hence:
//   Jacobi(c, q) == -1  <=>  r was applied  <=>  m is odd;
//   Jacobi(c, p) == -1  <=>  s was applied  <=>  Jacobi(m, n) == -1.
// The decryption strips the tags, takes square roots mod p and q, and then
// chooses, among the four roots, the one with the recorded Jacobi symbol
// and parity.
void RabinPrivateKey::RawDecrypt(const Integer &in, Integer &out) const
{
	// Work modulo each prime separately.
	Integer cp=in%p, cq=in%q;

	// Recover the tag bits from the Jacobi symbols of the ciphertext.
	int jp = Jacobi(cp, p);
	int jq = Jacobi(cq, q);

	// r was multiplied in (m odd): divide it back out.
	if (jq==-1)
	{
		cp = cp*EuclideanMultiplicativeInverse(r, p)%p;
		cq = cq*EuclideanMultiplicativeInverse(r, q)%q;
	}

	// s was multiplied in (Jacobi(m, n) == -1): divide it back out.
	if (jp==-1)
	{
		cp = cp*EuclideanMultiplicativeInverse(s, p)%p;
		cq = cq*EuclideanMultiplicativeInverse(s, q)%q;
	}

	// cp, cq are now m^2 mod p and m^2 mod q.
	// NOTE(review): the sign fix-up below assumes ModularSquareRoot returns the
	// root that is itself a quadratic residue (e.g. a^((p+1)/4) for p = 3 mod 4).
	// Confirm against nbtheory.cpp.
	cp = ModularSquareRoot(cp, p);
	cq = ModularSquareRoot(cq, q);

	// Negating the root mod p flips its Jacobi symbol mod p (for p = 3 mod 4).
	// This gives the recombined result Jacobi(out, n) == -1, matching m.
	if (jp==-1)
		cp = p-cp;

	// Recombine with the CRT. u = q^-1 mod p, as asserted in the constructors.
	// Presumably CRT(xq, q, xp, p, u) returns the x mod n with x = xq (mod q)
	// and x = xp (mod p); confirm against nbtheory.h.
	out = CRT(cq, q, cp, p, u);

	// Fix the parity to match the r tag. n is odd, so n-out flips parity.
	// It also preserves Jacobi(out, n), because Jacobi(-1, n) = 1 when both
	// primes are 3 mod 4. out[0] is used here as the low (parity) bit, as in
	// RawEncrypt.
	if ((jq==-1 && out[0]==0) || (jq==1 && out[0]==1))
		out = n-out;
}
