"""
    Created on Thu May 12 2016
    
    @author: Floriane Gidel
    """

import numpy as np
from sympy import *


def varphi_expr(i, n, H0):
    """Return the i-th Lagrange basis polynomial in the symbol z.

    The n + 1 interpolation nodes are sigma_k = (n - k) * H0 / n for
    k = 0, ..., n, so sigma_0 = H0 and sigma_n = 0.  The result is the product
    over k != i of (z - sigma_k) / (sigma_i - sigma_k): it equals 1 at
    z = sigma_i and 0 at every other node.

    Parameters
    ----------
    i : int
        Index of the basis polynomial, 0 <= i <= n.
    n : int
        Polynomial degree (number of nodes minus one), n >= 1.
    H0 : number or sympy expression
        Upper end of the interval [0, H0] spanned by the nodes.

    Returns
    -------
    sympy expression
        Polynomial of degree n in ``Symbol('z')``.

    Raises
    ------
    ValueError
        If n < 1 (the node spacing H0 / n is undefined) or if i is not in
        0..n (there is no node sigma_i).
    """
    if n < 1:
        raise ValueError("n must be >= 1, got n=%r" % (n,))
    if not 0 <= i <= n:
        raise ValueError("i must satisfy 0 <= i <= n, got i=%r, n=%r" % (i, n))

    z = Symbol('z')

    def sigma(k):
        # Rational keeps nodes exact for an integer H0 (and is immune to
        # Python 2 floor division); a float H0 still yields a sympy Float.
        return Rational(n - k, n) * H0

    sigma_i = sigma(i)
    # Skipping k == i avoids the 0/0 factor at the node itself.
    factors = [(z - sigma(k)) / (sigma_i - sigma(k))
               for k in range(n + 1) if k != i]
    return Mul(*factors)

def deriv_varphi_expr(varphi_expr):
    """Return the first derivative of ``varphi_expr`` with respect to z.

    ``varphi_expr`` is expected to be a sympy expression in ``Symbol('z')``,
    such as one returned by the module-level function of the same name
    (which this parameter shadows inside the body).
    """
    return diff(varphi_expr, Symbol('z'))
def WM_i(i, n, H0):
    """Return the integral over [0, H0] of the i-th basis polynomial varphi_i(z)."""
    depth = Symbol('z')
    basis = varphi_expr(i, n, H0)
    return integrate(basis, (depth, 0, H0))

def M_ij(i, j, n, H0):
    """Return the integral over [0, H0] of varphi_i(z) * varphi_j(z)."""
    z = Symbol('z')
    phi_i, phi_j = varphi_expr(i, n, H0), varphi_expr(j, n, H0)
    return integrate(phi_i * phi_j, (z, 0, H0))

def A_ij(i, j, n, H0):
    """Return the integral over [0, H0] of varphi_i'(z) * varphi_j'(z).

    Both factors are z-derivatives of the basis polynomials from
    varphi_expr, so the integrand (and result) is symmetric in i and j.
    """
    z = Symbol('z')
    d_phi = [deriv_varphi_expr(varphi_expr(m, n, H0)) for m in (i, j)]
    return integrate(d_phi[0] * d_phi[1], (z, 0, H0))

def D_ij(i, j, n, H0):
    """Return the integral over [0, H0] of z * varphi_i(z) * varphi_j'(z).

    Warning: NOT symmetric in i and j -- only the j-th basis polynomial is
    differentiated.
    """
    z = Symbol('z')
    phi_i = varphi_expr(i, n, H0)
    dphi_j = deriv_varphi_expr(varphi_expr(j, n, H0))
    integrand = z * phi_i * dphi_j
    return integrate(integrand, (z, 0, H0))

def S_ij(i, j, n, H0):
    """Return the integral over [0, H0] of z**2 * varphi_i'(z) * varphi_j'(z)."""
    z = Symbol('z')
    dphi_i = deriv_varphi_expr(varphi_expr(i, n, H0))
    dphi_j = deriv_varphi_expr(varphi_expr(j, n, H0))
    integrand = z**2 * dphi_i * dphi_j
    return integrate(integrand, (z, 0, H0))


def test_calcul_matrices(H0):
    """Print computed matrix entries next to hand-derived reference values.

    Covers the case n = 2, i = 0, j = 1 for an arbitrary depth H0.  Each
    reference is written as an antiderivative F(z) of the integrand and is
    evaluated as F(H0) - F(0).  Nothing is asserted: for every entry the
    computed value is printed first and the reference value second.

    The hand derivation uses the basis
        varphi_0 = (z - H0/2) * (z - H0) / (H0**2 / 2)
        varphi_1 = 4 * z * (H0 - z) / H0**2
    NOTE(review): this varphi_0 equals 1 at z = 0, whereas varphi_expr puts
    its nodes at sigma_k = (n - k) * H0 / n, where varphi_0 equals 1 at
    z = H0.  A and M are unchanged by the reflection z -> H0 - z and agree
    for either ordering; B, C, D and S may not.  Confirm which node ordering
    is intended before trusting a mismatch.
    """
    print("----")
    n = 2
    i, j = 0, 1
    z = Symbol('z')

    varphi_i = varphi_expr(i, n, H0)
    varphi_j = varphi_expr(j, n, H0)
    d_varphi_i = deriv_varphi_expr(varphi_i)
    d_varphi_j = deriv_varphi_expr(varphi_j)

    def integral(expr):
        # Definite integral over the depth [0, H0].
        return integrate(expr, (z, 0, H0))

    def report(label, computed, antiderivative):
        # Print the computed entry, then the reference F(H0) - F(0).
        print(" test %s:" % label)
        print(computed)
        print(antiderivative.subs(z, H0) - antiderivative.subs(z, 0))
        print("----")

    # Float literals throughout so no reference coefficient is truncated by
    # Python 2 integer division when H0 is an int.
    report("Aij", A_ij(i, j, n, H0),
           (8.0/H0**4)*((5.0/2.0)*H0*z**2 - (4.0/3.0)*z**3
                        - (3.0/2.0)*H0**2*z))
    report("Mij", M_ij(i, j, n, H0),
           (-8.0/H0**4)*((1.0/5.0)*z**5 - (5.0/8.0)*H0*z**4
                         + (2.0/3.0)*H0**2*z**3 - (1.0/4.0)*H0**3*z**2))
    # B and C have no helper in this file; they are integrated directly as
    #   B_ij = int varphi_i * varphi_j' dz,  C_ij = int z * varphi_i' * varphi_j' dz,
    # the integrands whose antiderivatives are the references given here.
    report("Bij", integral(varphi_i*d_varphi_j),
           (8.0/H0**4)*((-1.0/2.0)*z**4 + (4.0/3.0)*H0*z**3
                        - (5.0/4.0)*H0**2*z**2 + (1.0/2.0)*H0**3*z))
    report("Bji", integral(varphi_j*d_varphi_i),
           (8.0/H0**4)*((-1.0/2.0)*z**4 + (7.0/6.0)*H0*z**3
                        - (3.0/4.0)*H0**2*z**2))
    report("Cij", integral(z*d_varphi_i*d_varphi_j),
           (8.0/H0**4)*((5.0/3.0)*H0*z**3 - z**4 - (3.0/4.0)*H0**2*z**2))
    report("Dij", D_ij(i, j, n, H0),
           (8.0/H0**4)*((-2.0/5.0)*z**5 + H0*z**4
                        - (5.0/6.0)*H0**2*z**3 + (1.0/4.0)*H0**3*z**2))
    report("Dji", D_ij(j, i, n, H0),
           (8.0/H0**4)*((7.0/8.0)*H0*z**4 - 0.5*H0**2*z**3
                        - (2.0/5.0)*z**5))
    report("Sij", S_ij(i, j, n, H0),
           (8.0/H0**4)*((-4.0/5.0)*z**5 + (5.0/4.0)*H0*z**4
                        - 0.5*H0**2*z**3))


