//
// LiDIA - a library for computational number theory
//   Copyright (c) 1994, 1995 by the LiDIA Group
//
// File        : poly_modulus.c
// Author      : Victor Shoup, Thomas Pfahler (TPf)
// Last change : TPf, Feb 29, 1996, initial version
//

#if defined(HAVE_MAC_DIRS) || defined(__MWERKS__)
#include <LiDIA:poly_modulus.h>
#include <LiDIA:poly_multiplier.h>
#else
#include <LiDIA/poly_modulus.h>
#include <LiDIA/poly_multiplier.h>
#endif



/******************************************************************

			  class poly_modulus

*******************************************************************/
// If you need to do a lot of arithmetic modulo a fixed f,
// build poly_modulus F for f.  This pre-computes information about f
// that speeds up the computation a great deal.
// f should be monic, and deg(f) > 0.

void poly_modulus::build(const Fp_polynomial& pol)
{
    debug_handler( "poly_modulus", "build( Fp_polynomial& )" );

//    if ( f == pol )
//  	return; 	//initializing with the same polynomial again

    f.assign( pol );
    deg_f = pol.degree();

    if (deg_f <= 0)
	lidia_error_handler( "poly_modulus", "build( Fp_polynomial& )::f.degree() must be at least 1" );
    if ( !pol.is_monic() )
	lidia_error_handler( "poly_modulus", "build( Fp_polynomial& )::modulus must be monic" );

    crov = Fp_polynomial::crossovers.fftmul_crossover(f.modulus());
    if (deg_f <= crov)
    {
	use_FFT = false;
	return;
    }

    use_FFT = true;

    k = next_power_of_two(deg_f);
    l = next_power_of_two(2*deg_f - 3);

    lidia_size_t l_safe = next_power_of_two(2*deg_f - 1);
    //in functions multiply,square it must be guaranteed that both fft_reps
    //use an fft_table of at least this size
    
    F.init(l_safe, pol.MOD);

    FRep.init(k, F);
    FRep.to_fft_rep(pol);

    Fp_polynomial P1, P2;
    P1.set_max_degree(deg_f);
    P2.set_max_degree(deg_f-1);

    copy_reverse(P1, pol, 0, deg_f);
    invert(P2, P1, deg_f-1);
    copy_reverse(P1, P2, 0, deg_f-2);
    
    HRep.init(l, F);
    HRep.to_fft_rep(P1);

}


// Disabled code: everything between "#if 0" and "#endif" is excluded from
// compilation.
// As written, f_star would reverse copy_of_f, pass the result through
// invert(..., copy_of_f.degree()) and store copy_reverse(..., 0, degree-2)
// of that in `a`.
// NOTE(review): `copy_of_f` is not declared in the visible part of this
// file, and build() performs a similar reverse/invert/reverse sequence but
// passes deg_f-1 to invert -- confirm both points before re-enabling.
#if 0
void poly_modulus::f_star(Fp_polynomial &a) const
{
    Fp_polynomial P1, P2;
    copy_reverse(P1, copy_of_f, 0, copy_of_f.degree());
    invert(P2, P1, copy_of_f.degree());
    copy_reverse(a, P2, 0, copy_of_f.degree()-2);
}
#endif



// x = a % f, using the data precomputed by build().
// Precondition: deg(a) <= 2(n-1) with n = deg(f); a larger input is
// reported through lidia_error_handler.
// x may be the same object as a (remainder() calls it that way).
void poly_modulus::rem21(Fp_polynomial& x, const Fp_polynomial& a) const
{
    debug_handler( "poly_modulus", "rem21( Fp_polynomial&, Fp_polynomial& )" );

    lidia_size_t deg_a = a.degree();
    lidia_size_t n = deg_f;

    if (deg_a > 2*n-2)
        lidia_error_handler( "poly_modulus", "rem21( Fp_polynomial&, Fp_polynomial& )::bad_args" );

    // a is already reduced: just copy it
    if (deg_a < n)
    {
        x.assign( a );
        return;
    }

    // no FFT data available, or too few excess coefficients to be worth it
    if (!use_FFT || (deg_a - n) <= crov)
    {
        plain_rem(x, a, f);
        return;
    }

    modular_fft_rep rep(HRep);
    Fp_polynomial t;
    t.set_max_degree(n-1);

    lidia_size_t prime;

    // pass 1: coefficients n..2(n-1) of a, multiplied by HRep;
    // coefficients n-2..2n-4 of the product are collected in t
    for (prime = 0; prime < rep.number_of_primes(); prime++)
    {
        rep.to_modular_fft_rep(a, n, 2*(n-1), prime);
        multiply(rep, HRep, rep, prime);
        rep.from_modular_fft_rep(n-2, 2*n-4, prime);
    }
    rep.get_result(t, n-2, 2*n-4);

    // pass 2: t multiplied by FRep at size k;
    // coefficients 0..n-1 of the product replace t
    rep.set_size(k);
    for (prime = 0; prime < rep.number_of_primes(); prime++)
    {
        rep.to_modular_fft_rep(t, prime);
        multiply(rep, FRep, rep, prime);
        rep.from_modular_fft_rep(0, n-1, prime);
    }
    rep.get_result(t, 0, n-1);

    lidia_size_t deg_t = t.degree();
    lidia_size_t wrap = 1 << k;

    // The coefficient pointers are fetched only after x has been resized;
    // this order is kept on purpose because x and a may be the same object.
    x.MOD = a.MOD;
    x.set_degree(n-1);
    const bigint* src = a.coeff;
    const bigint* sub = t.coeff;
    bigint* dst = x.coeff;
    const bigint & p = t.modulus();

    // x[i] = a[i] - t[i] (mod p), plus a[i + 2^k] whenever that
    // coefficient of a exists
    lidia_size_t i;
    for (i = 0; i < n; i++)
    {
        if (i > deg_t)
            dst[i].assign( src[i] );
        else
            SubMod(dst[i], src[i], sub[i], p);

        if (i + wrap <= deg_a)
            AddMod(dst[i], dst[i], src[i+wrap], p);
    }

    x.remove_leading_zeros();
}



// x = a % f for a of arbitrary degree.
// With n = F.deg(): inputs of degree <= 2n-2 are handed directly to
// rem21; larger inputs use plain_rem when F carries no FFT data, and are
// otherwise pushed through rem21 block by block, beginning with the
// high-order coefficients of a.
void remainder(Fp_polynomial& x, const Fp_polynomial& a, const poly_modulus& F)
{
    debug_handler( "poly_modulus", "remainder( Fp_polynomial&, Fp_polynomial&, poly_modulus& )" );

    a.comp_modulus(F.modulus(), "remainder");

    lidia_size_t deg_a = a.degree();
    lidia_size_t n = F.deg();

    if (deg_a <= 2*n-2)
    {
        F.rem21(x, a);
        return;
    }

    if (!F.use_fft())
    {
        plain_rem(x, a, F.modulus());
        return;
    }

    // working buffer: holds at most 2n-1 coefficients at any time
    Fp_polynomial window;
    F.forward_modulus(window);
    window.set_max_degree(2*n-2);

    // number of coefficients of a that have not been consumed yet
    lidia_size_t remaining = deg_a + 1;

    while (remaining > 0)
    {
        lidia_size_t kept = window.degree() + 1;
        lidia_size_t take = comparator<lidia_size_t>::min(2*n - 1 - kept, remaining);
        lidia_size_t top = kept + take - 1;

        window.set_degree(top);

        lidia_size_t i;

        // move the current partial remainder up by `take` positions ...
        for (i = top; i >= take; i--)
            window[i].assign( window[i-take] );

        // ... and put the next `take` coefficients of a underneath it
        for (i = take-1; i >= 0; i--)
            window[i].assign( a[remaining-take+i] );
        window.remove_leading_zeros();

        F.rem21(window, window);
        remaining -= take;
    }

    x.assign( window );
}



// x = (a * b) % f
// a, b and F must share the same coefficient modulus, and both
// deg(a) and deg(b) must be smaller than F.deg(); a violation of the
// degree bound is reported through lidia_error_handler.
void multiply(Fp_polynomial& x, const Fp_polynomial& a, const Fp_polynomial& b, const poly_modulus& F)
{
    debug_handler( "poly_modulus", "multiply( Fp_polynomial&, Fp_polynomial&, Fp_polynomial&, poly_modulus& )" );

    a.comp_modulus(b, "multiply");
    a.comp_modulus(F.modulus(), "multiply");

    lidia_size_t deg_a = a.degree();
    lidia_size_t deg_b = b.degree();
    lidia_size_t n = F.deg();

    if (deg_a >= n || deg_b >= n)
        lidia_error_handler( "poly_modulus", "multiply( Fp_polynomial&, Fp_polynomial&, Fp_polynomial&, poly_modulus& )::degree of Fp_polynomials must be < degree of poly_modulus" );

    // small case: classical product followed by classical division
    if (!F.use_fft() || deg_a + deg_b - n <= F.crov)
    {
        Fp_polynomial prod, quot;
        plain_mul(prod, a, b);
        plain_div_rem(quot, x, prod, F.modulus());
        return;
    }

    // number of coefficients of a*b, and the transform size used for it;
    // since deg_a, deg_b < n, size <= next_power_of_two( 2*deg_f - 1 )
    lidia_size_t prod_len = deg_a + deg_b + 1;
    lidia_size_t size = next_power_of_two(prod_len);
    size = comparator<lidia_size_t>::max(size, F.k);

    fft_rep full(F.HRep);
    full.set_size(size);
    modular_fft_rep work(F.HRep), corr(F.FRep);
    Fp_polynomial t;
    t.set_max_degree(n - 1);

    work.set_size(size);

    lidia_size_t prime;

    // pass 1: transform of a times transform of b; the product is kept
    // in `full` for pass 3 (reduce acts as a copy here), and its
    // coefficients n..prod_len-1 are collected in t
    full.to_fft_rep(a);
    for (prime = 0; prime < work.number_of_primes(); prime++)
    {
        work.to_modular_fft_rep(b, prime);
        multiply(work, full, work, prime);
        reduce(full, work, size, prime);
        work.from_modular_fft_rep(n, prod_len-1, prime);
    }
    work.get_result(t, n, prod_len-1);

    // pass 2: t times HRep at size F.l; coefficients n-2..2n-4 replace t
    work.set_size(F.l);
    for (prime = 0; prime < work.number_of_primes(); prime++)
    {
        work.to_modular_fft_rep(t, prime);
        multiply(work, F.HRep, work, prime);
        work.from_modular_fft_rep(n-2, 2*n-4, prime);
    }
    work.get_result(t, n-2, 2*n-4);

    // pass 3: (saved product reduced to size F.k) minus (t times FRep);
    // coefficients 0..n-1 of the difference are the result
    work.set_size(F.k);
    for (prime = 0; prime < work.number_of_primes(); prime++)
    {
        corr.to_modular_fft_rep(t, prime);
        multiply(corr, F.FRep, corr, prime);
        reduce(work, full, F.k, prime);
        subtract(work, work, corr, prime);
        work.from_modular_fft_rep(0, n-1, prime);
    }
    work.get_result(x, 0, n-1);
}



// x = a^2 % f
// a and F must share the same coefficient modulus, and a.degree() must be
// smaller than F.deg().  A violation of the degree bound is reported
// through lidia_error_handler, the same way multiply() treats its operands.
void square(Fp_polynomial& x, const Fp_polynomial& a, const poly_modulus& F)
{
    debug_handler( "poly_modulus", "square( Fp_polynomial&, Fp_polynomial&, poly_modulus& )" );

    a.comp_modulus(F.modulus(), "square");

    lidia_size_t  da, d, n, k, index;

    da = a.degree();
    n = F.deg();

    // The bound k <= next_power_of_two( 2*deg_f - 1 ) used below only holds
    // for da < n, so an oversized operand is an error and must not be
    // silently traced via debug_handler.
    if (da >= n)
        lidia_error_handler( "poly_modulus", "square( Fp_polynomial&, Fp_polynomial&, poly_modulus& )::degree of Fp_polynomial must be < degree of poly_modulus" );

    // small case: classical squaring followed by classical division
    if (!F.use_fft() || 2*da - n <= F.crov)
    {
        Fp_polynomial P1, P2;
        plain_sqr(P2, a);
        plain_div_rem(P1, x, P2, F.modulus());
        return;
    }


    // d = number of coefficients of a^2
    d = 2*da + 1;

    k = next_power_of_two(d);
    k = comparator<lidia_size_t>::max(k, F.k);
    //here, k <= next_power_of_two( 2*deg_f - 1 )


    fft_rep R1(F.HRep);
    R1.set_size(k);
    modular_fft_rep R2(F.HRep), R3(F.FRep);
    Fp_polynomial P1;
    P1.set_max_degree(n - 1);

    // pass 1: transform of a multiplied by itself; coefficients n..d-1 of
    // the square go to P1 while R1 keeps the transformed product
    R1.to_fft_rep(a);
    multiply(R1, R1, R1);
    R1.from_fft_rep(P1, n, d-1);  // save R1 for future use

    // pass 2: P1 times HRep; coefficients n-2..2n-4 replace P1
    for (index=0; index < R2.number_of_primes(); index++)
    {
        R2.to_modular_fft_rep(P1, index);
        multiply(R2, F.HRep, R2, index);
        R2.from_modular_fft_rep(n-2, 2*n-4, index);
    }
    R2.get_result(P1, n-2, 2*n-4);

    // pass 3: (saved product reduced to size F.k) minus (P1 times FRep);
    // coefficients 0..n-1 of the difference are the result
    R2.set_size(F.k);
    for (index=0; index < R2.number_of_primes(); index++)
    {
        R3.to_modular_fft_rep(P1, index);
        multiply(R3, F.FRep, R3, index);
        reduce(R2, R1, F.k, index);
        subtract(R2, R2, R3, index);
        R2.from_modular_fft_rep(0, n-1, index);
    }
    R2.get_result(x, 0, n-1);
}



// h = g^e % f, computed by scanning the bits of e from the most
// significant one downwards (square every step, multiply by g when the
// bit is set).  A negative exponent is reported through
// lidia_error_handler.
void power(Fp_polynomial& h, const Fp_polynomial& g, const bigint& e, const poly_modulus& F)
{
    debug_handler( "poly_modulus", "power( Fp_polynomial&, Fp_polynomial&, bigint&, poly_modulus& )" );

    g.comp_modulus(F.modulus(), "power");

    if (e.is_negative())
        lidia_error_handler( "poly_modulus", "power( Fp_polynomial&, Fp_polynomial&, bigint&, poly_modulus& )::exponent must be positive" );

    int pos = e.bit_length();

    // G is built from g before h is overwritten below
    poly_multiplier G(g, F);

    // start from h = 1 over the modulus of F
    h.set_max_degree(F.deg() - 1);
    F.forward_modulus(h);
    h.assign_one();

    while (pos > 0)
    {
        --pos;
        square(h, h, F);
        if (e.bit(pos))
        {
            multiply(h, h, G, F);	//poly_multiplier
        }
    }
}


// h = X^e % f, computed by scanning the bits of e from the most
// significant one downwards (square every step, multiply by X when the
// bit is set).  A negative exponent is reported through
// lidia_error_handler.
void power_x(Fp_polynomial& h, const bigint& e, const poly_modulus& F)
{
    debug_handler( "poly_modulus", "power_x( Fp_polynomial&, bigint&, poly_modulus& )" );

    if (e.is_negative())
        lidia_error_handler( "poly_modulus", "power_x( Fp_polynomial&, bigint&, poly_modulus& )::exponent must be positive" );

    int pos = e.bit_length();

    // start from h = 1 over the modulus of F
    h.set_max_degree(F.deg() - 1);
    F.forward_modulus(h);
    h.assign_one();

    while (pos > 0)
    {
        --pos;
        square(h, h, F);
        if (e.bit(pos))
        {
            multiply_by_x_mod(h, h, F.modulus());
        }
    }
}


// h = (X + a)^e % f, computed by scanning the bits of e from the most
// significant one downwards (square every step, multiply by X + a when
// the bit is set).  A negative exponent is reported through
// lidia_error_handler.
void power_x_plus_a(Fp_polynomial& h, const bigint& a, const bigint& e, const poly_modulus& F)
{
    debug_handler( "poly_modulus", "power_x_plus_a( Fp_polynomial&, bigint&, bigint&, poly_modulus& )" );

    if (e.is_negative())
        lidia_error_handler( "poly_modulus", "power_x_plus_a( Fp_polynomial&, bigint&, bigint&, poly_modulus& )::exponent must be positive" );

    // scratch polynomials for the two summands X*h and a*h
    Fp_polynomial x_times_h, a_times_h;
    x_times_h.set_max_degree(F.deg()-1);
    a_times_h.set_max_degree(F.deg()-1);

    // private, reduced copy of a, taken before h is touched
    // (allows input to alias output)
    bigint a_red;
    Remainder(a_red, a, F.modulus().modulus());

    int pos = e.bit_length();

    // start from h = 1 over the modulus of F
    h.set_max_degree(F.deg() - 1);
    F.forward_modulus(h);
    h.assign_one();

    while (pos > 0)
    {
        --pos;
        square(h, h, F);
        if (e.bit(pos))
        {
            // h = (X + a) * h  =  X*h  +  a*h
            multiply_by_x_mod(x_times_h, h, F.modulus());
            multiply_by_scalar(a_times_h, h, a_red);
            add(h, x_times_h, a_times_h);
        }
    }
}




