"""
 Solving 3D nonlinear potential flow equations
 Created 26 May 2016
 Author : Floriane Gidel
"""
import time

import numpy as np

# Order matters: names from firedrake override same-named names from discr_sigma.
from discr_sigma import *
from firedrake import *


star_time = time.time()                   # wall-clock start of the run (seconds), reported at the end
"""
*********************************************
*    Definition of the space/time domain    *
********************************************* """
# Geometry: rectangular basin of size Lx x Ly with still-water depth H0.
Lx, Ly, H0 = 2.0, 1.0, 1.0
# Time stepping: from t = 0 up to Tend with a fixed step dt.
t, Tend, dt = 0.0, 1.5, 0.0028
g = 9.81                                  # gravitational acceleration

"""
*********************************************
*           Definition of the mesh          *
********************************************* """
#________ Vertical discretization _________#
n = 2                                                   # Order of the expansion
Nz = n + 1                                              # Number of points in one (vertical) element

#_______ Horizontal discretization ________#
res = 0.1                                               # Target mesh spacing
# Number of nodes per direction. Lx/res is a float, so round it to the nearest
# integer (mesh constructors need integer counts, and Lx/res may be inexact).
Nx = int(round(Lx/res)) + 1                             # Number of nodes in x
Ny = int(round(Ly/res)) + 1                             # Number of nodes in y
# RectangleMesh takes the number of *cells* per direction, i.e. nodes - 1.
# Passing the node count directly would add one extra cell in each direction
# and the spacing would no longer be res.
hor_mesh = RectangleMesh(Nx - 1, Ny - 1, Lx, Ly)        # Surface mesh
# A single extruded layer of height H0 spans the full water column.
mesh_3D = ExtrudedMesh(hor_mesh, 1, layer_height=H0)    # Full mesh (3D)


"""
*********************************************
*      Definition of the Function Space     *
********************************************* """
# Surface scalar space (continuous P1 on the horizontal mesh): holds psi_s and h.
V = FunctionSpace(hor_mesh, "CG", 1)
# Extruded 3D space for phi(x,y,z,t) and varphi_tilde: CG of degree 1 in the
# horizontal, Lagrange polynomials of degree n along the vertical.
V_3D = FunctionSpace(mesh_3D, "CG", 1,
                     vfamily="Lagrange", vdegree=n)
# Surface vector space with n components: holds hat_psi (interior values of phi).
Vec = VectorFunctionSpace(hor_mesh, "CG", 1, dim=n)
# Extruded vector space used for varphi_tilde. No dim= is passed, so the number
# of components is the library default. NOTE(review): confirm it matches what
# varphi_tilde needs (hat_psi has n components).
Vec_3D = VectorFunctionSpace(mesh_3D, "CG", 1,
                             vfamily="Lagrange", vdegree=n)
# Mixed space of (scalar, n-vector) pairs, used to solve the coupled weak forms.
W = V * Vec


"""
*********************************************
*        Definition of the Functions        *
********************************************* """
#_____________ Test functions _____________#
# Scalar test function for the surface-only form WF3 (step 3, unknown psi_s_n1).
v = TestFunction(V)
# Test function on the mixed space W = V*Vec for the coupled forms of steps 1 and 2;
# p tests the scalar (surface) component and q the n-component vector part.
w = TestFunction(W)
p, q = split(w)
# NOTE(review): r is not referenced by any weak form visible in this script.
r = TestFunction(Vec_3D)
#________________ Unknowns ________________#
# --- Time t^n --- :
h_n0 = Function(V)                                      # h at time level n
psi_s_n0 = Function(V)                                  # psi_s at time level n (i.e. surface value of phi(x,y,z,t))
hat_psi_n0 = Function(Vec)                              # hat(psi) at time level n (i.e. interior values of phi(x,y,z,t))

# --- Time t^{n+1/2} --- :
# w1 holds the step-1 unknowns; split(w1) yields symbolic views of its two
# components, which the weak forms below use directly.
w1 = Function(W)
psi_s_half, hat_psi_half = split(w1)                    # psi_s and hat(psi) at time n+1/2

# --- Time t^{n+1} --- :
# w2 holds the step-2 unknowns (h and hat(psi) at the new time level).
w2 = Function(W)
h_n1, hat_psi_n1 = split(w2)                            # h and hat(psi) at time n + 1

psi_s_n1 = Function(V)                                  # psi_s at time n + 1 (step-3 unknown)
phi = Function(V_3D)                                    # 3D potential phi(x,y,z,t)
# NOTE(review): varphi_tilde and z_coord are not used in the visible code yet.
varphi_tilde = Function(Vec_3D)
# Vertical coordinate x[2] of the extruded mesh, interpolated into V_3D.
z_coord = Function(V_3D).interpolate(Expression('x[2]'))

#________________ Exact linear ________________#
# Exact linear solution on the surface mesh, written out for comparison.
h_exact = Function(V)
psi_s_exact = Function(V)


"""
*********************************************
*      Initialisation of the Functions      *
*********************************************"""
lam = pi/Lx                                             # wavenumber of the cos(lam*x[0]) mode
# Linear dispersion relation ww**2 = g*lam*tanh(lam*H0). Evaluated with NumPy and
# cast to float so that ww is a plain number: it is passed as an Expression
# parameter below, rather than whatever object the star-imported sqrt/tanh return.
ww = float(np.sqrt(g*lam*np.tanh(lam*H0)))
a = 0.1                                                 # amplitude of the cos(ww*t) part of psi_s
b = 0.1                                                 # amplitude of the sin(ww*t) part of psi_s
# Expression string for the linear free-surface height, shared by h_n0 and w2.
h_init_code = "H0 -(1.0/g)*(-w*a*sin(w*t)+w*b*cos(w*t))*cos(lam*x[0])"
h_init_params = dict(H0=H0, g=g, w=ww, t=t, a=a, b=b, lam=lam)
h_n0.interpolate(Expression(h_init_code, **h_init_params))
# w2 = (h, hat_psi): the same h as above plus n zero interior components. The zeros
# are generated from n (instead of hard-coding two), so the tuple length always
# matches the mixed space W = V*Vec.
w2.interpolate(Expression((h_init_code,) + ("0.0",) * n, **h_init_params))
psi_s_n0.interpolate(Expression("(a*cos(w*t)+b*sin(w*t))*cos(lam*x[0])", a=a, b=b, w=ww, t=t, lam=lam))



"""
********************************************
*           Exact Linear solutions         *
******************************************** """
# Exact linear standing-wave solution. The Expressions keep a time parameter `t`,
# which the time loop resets before re-interpolating the Functions below.
h_expr = Expression(
    "H0 -(1.0/g)*(-w*a*sin(w*t)+w*b*cos(w*t))*cos(lam*x[0])",
    H0=H0, g=g, w=ww, t=t, a=a, b=b, lam=lam)
psi_s_expr = Expression(
    "(a*cos(w*t)+b*sin(w*t))*cos(lam*x[0])",
    a=a, b=b, w=ww, t=t, lam=lam)

# Sample both exact solutions at the initial time.
for _field, _source in ((h_exact, h_expr), (psi_s_exact, psi_s_expr)):
    _field.interpolate(_source)



"""
********************************************
*               Saving Files               *
******************************************** """
# ParaView (.pvd) output series, one per field, all under data_quad2/.
_out = "data_quad2/"
save_h = File(_out + "h.pvd")
save_psi_s = File(_out + "psi_s.pvd")
save_hat_psi = File(_out + "hat_psi.pvd")
save_phi = File(_out + "phi.pvd")
# Exact linear solutions, written alongside the numerical ones for comparison.
save_h_ex = File(_out + "h_ex.pvd")
save_psi_s_ex = File(_out + "psi_s_ex.pvd")

"""
*******************************************
*           Compute the matrices          *
******************************************* """
#_____________ Initialization ____________#
# Nz x Nz coefficient matrices, filled entry by entry below.
A = np.zeros((Nz, Nz))
M = np.zeros((Nz, Nz))
D = np.zeros((Nz, Nz))
S = np.zeros((Nz, Nz))

#______ Evaluation of each integral ______#
# A_ij, M_ij, D_ij and S_ij come from discr_sigma (star import); presumably
# vertical integrals of the Lagrange basis functions/derivatives — see there.
# row/col are used instead of i/j so that any UFL index objects named i, j
# brought in by the firedrake star import are not rebound.
for row in range(Nz):
    for col in range(Nz):
        A[row, col] = A_ij(row, col, n, H0)
        M[row, col] = M_ij(row, col, n, H0)
        D[row, col] = D_ij(row, col, n, H0)
        S[row, col] = S_ij(row, col, n, H0)

#______________ Submatrices ______________#
# Index 0 pairs with the surface value psi_s in the weak forms and indices 1..n
# with the interior values hat_psi. Each matrix X is split as
#     [ X11  X1N ]
#     [ XN1  XNN ]
# where X11 is a NumPy scalar and the other blocks are wrapped as UFL tensors.
def _split_blocks(mat):
    """Return (mat[0,0], mat[0,1:], mat[1:,0], mat[1:,1:]); all but the first via as_tensor."""
    return (mat[0, 0], as_tensor(mat[0, 1:]),
            as_tensor(mat[1:, 0]), as_tensor(mat[1:, 1:]))

A11, A1N, AN1, ANN = _split_blocks(A)
M11, M1N, MN1, MNN = _split_blocks(M)
D11, D1N, DN1, DNN = _split_blocks(D)
S11, S1N, SN1, SNN = _split_blocks(S)

# Length-n vector of ones; dot(X, I3) in the weak forms sums X over its last index.
Ii = [1.0] * n
I3 = as_tensor(Ii)

"""
*******************************************
*   Compute the vertical basis functions  *
*******************************************
----->
----->
-----> TODO: define and compute varphi_tilde. Is there a "product" expression that gives varphi(z) as an expression of z?
----->
----->
"""



"""
********************************************
*       Define the weak formulations       *
******************************************** """


# --- Step 1 : get psi_s(t^{n+1/2}) and hat_psi(t^{n+1/2})
# Unknowns: psi_s_half and hat_psi_half, i.e. the two components of w1 (solved
# with solve(WF1 == 0, w1) in the time loop). The depth is frozen at h_n0.
# The form contains products of the unknowns, so it is a nonlinear residual.

# Surface equation, tested with p: (psi_s_half - psi_s_n0) plus dt/(4*H0) times
# the matrix-weighted gradient terms, the h-gradient terms and the
# hydrostatic term 2*g*H0*p*(h_n0-H0).
WF_psi_s = (p*psi_s_half - p*psi_s_n0 + (dt/(4*H0))*( p*M11*(nabla_grad(psi_s_half))**2 -2*D11*psi_s_half*dot(nabla_grad(p),nabla_grad(psi_s_half)) + S11*(psi_s_half**2)*( (2/h_n0)*dot(nabla_grad(h_n0),nabla_grad(p)) - (p/(h_n0**2))*(nabla_grad(h_n0))**2) -p*A11*(psi_s_half**2)*(H0/h_n0)**2 + 2*g*H0*p*(h_n0-H0)  +2*p*dot(dot(MN1,nabla_grad(hat_psi_half).T),nabla_grad(psi_s_half)) -2*dot(hat_psi_half,D1N)*dot(nabla_grad(p),nabla_grad(psi_s_half)) + 2*psi_s_half*dot(hat_psi_half,SN1)*((2/h_n0)*dot(nabla_grad(h_n0),nabla_grad(p))- (p/h_n0**2)*(nabla_grad(h_n0))**2) -2*p*dot(hat_psi_half,AN1)*psi_s_half*(H0/h_n0)**2  -2*psi_s_half*dot(DN1,dot(nabla_grad(hat_psi_half).T,nabla_grad(p)))  +p*inner(dot(nabla_grad(hat_psi_half),MNN),nabla_grad(hat_psi_half)) -2*dot(dot(nabla_grad(p),nabla_grad(hat_psi_half)), dot(DNN,hat_psi_half)) +dot(dot(hat_psi_half,SNN),hat_psi_half)*((2/h_n0)*dot(nabla_grad(h_n0),nabla_grad(p))-(p/(h_n0**2))*(nabla_grad(h_n0))**2) -p*dot(dot(hat_psi_half,ANN),hat_psi_half)*(H0/h_n0)**2))*dx


# Interior equations, tested with the n components of q. This is a vector-valued
# integrand (one entry per interior node) and is not yet integrated: each
# component is multiplied by dx in WF1 below.
# NOTE(review): this form uses MN1, SN1, AN1 and untransposed MNN/SNN/ANN, while
# its step-2 counterpart WF_hat_psi_n1 uses M1N, S1N, A1N and MNN.T/SNN.T/ANN.T.
# The two agree only if M, S and A are symmetric — confirm against discr_sigma.
WF_hat_psi_half = elem_mult(MN1,dot(nabla_grad(q).T,nabla_grad(psi_s_half)))*h_n0 + elem_mult(SN1,q)*psi_s_half*dot(nabla_grad(h_n0),nabla_grad(h_n0))/h_n0  - elem_mult(D1N,q)*dot(nabla_grad(h_n0),nabla_grad(psi_s_half)) - elem_mult(DN1,dot(nabla_grad(q).T,nabla_grad(h_n0)))*psi_s_half + elem_mult(AN1,q)*psi_s_half*H0*H0/h_n0 + dot(elem_mult(MNN,dot(nabla_grad(q).T,nabla_grad(hat_psi_half))),I3)*h_n0 + elem_mult(dot(SNN,hat_psi_half),q)*dot(nabla_grad(h_n0),nabla_grad(h_n0))/h_n0 - elem_mult(dot(DNN.T,dot(nabla_grad(hat_psi_half).T,nabla_grad(h_n0))),q) - elem_mult(dot(DNN,hat_psi_half),dot(nabla_grad(q).T,nabla_grad(h_n0))) + elem_mult(dot(ANN,hat_psi_half),q)*H0*H0/h_n0

# Total step-1 residual: surface form plus the n integrated interior components.
WF1 = WF_psi_s + sum((WF_hat_psi_half[ind])*dx for ind in range(0,n))


# --- Step 2 : get h(t^{n+1}) and hat_psi(t^{n+1})
# Unknowns: h_n1 and hat_psi_n1, the two components of w2 (solved with
# solve(WF2 == 0, w2) in the time loop). psi_s_half and hat_psi_half come from
# step 1. Most terms appear twice, once at the old level (h_n0, hat_psi_half)
# and once at the new level (h_n1, hat_psi_n1), with coefficient dt/(2*H0).

# Equation for h, tested with p. It divides by h_n1, so h_n1 must stay nonzero.
WF_h = ( p*h_n1 - p*h_n0 - (dt/(2*H0))*(M11*dot(nabla_grad(psi_s_half),nabla_grad(p))*(h_n0+h_n1) + S11*p*psi_s_half*(((nabla_grad(h_n0))**2)/h_n0 + ((nabla_grad(h_n1))**2)/h_n1) -D11*p*(dot(nabla_grad(psi_s_half),nabla_grad(h_n0)) + dot(nabla_grad(psi_s_half), nabla_grad(h_n1))) -D11*psi_s_half*( dot(nabla_grad(p),nabla_grad(h_n0)) + dot(nabla_grad(p),nabla_grad(h_n1))) +(H0**2)*A11*psi_s_half*p*(1.0/h_n0 + 1.0/h_n1) + h_n0*dot(MN1,dot(nabla_grad(hat_psi_half).T,nabla_grad(p))) + h_n1*dot(MN1,dot(nabla_grad(hat_psi_n1).T,nabla_grad(p)))  +p*( dot(hat_psi_half,SN1)*(nabla_grad(h_n0))**2/h_n0 + dot(hat_psi_n1,SN1)*(nabla_grad(h_n1))**2/h_n1) -  dot(D1N,hat_psi_half)*dot(nabla_grad(h_n0),nabla_grad(p)) - dot(D1N,hat_psi_n1)*dot(nabla_grad(h_n1),nabla_grad(p)) -p*( dot(DN1,dot(nabla_grad(hat_psi_half).T,nabla_grad(h_n0))) + dot(DN1, dot(nabla_grad(hat_psi_n1).T,nabla_grad(h_n1))))  +(H0**2)*p*(dot(hat_psi_half,AN1)/h_n0 + dot(hat_psi_n1,AN1)/h_n1)))*dx

# Interior equations at the new level, tested with q. This is a vector-valued
# integrand that is integrated per component in WF2 below.
WF_hat_psi_n1 = elem_mult(M1N,dot(nabla_grad(q).T,nabla_grad(psi_s_half)))*h_n1 + elem_mult(S1N,q)*psi_s_half*dot(nabla_grad(h_n1),nabla_grad(h_n1))/h_n1 - elem_mult(D1N,q)*dot(nabla_grad(h_n1),nabla_grad(psi_s_half)) - elem_mult(DN1,dot(nabla_grad(q).T,nabla_grad(h_n1)))*psi_s_half + elem_mult(A1N,q)*psi_s_half*H0*H0/h_n1 + dot(elem_mult(MNN.T,dot(nabla_grad(q).T,nabla_grad(hat_psi_n1))),I3)*h_n1 + elem_mult(dot(SNN.T,hat_psi_n1),q)*dot(nabla_grad(h_n1),nabla_grad(h_n1))/h_n1 - elem_mult(dot(DNN.T,dot(nabla_grad(hat_psi_n1).T,nabla_grad(h_n1))),q) -elem_mult(dot(DNN,hat_psi_n1),dot(nabla_grad(q).T,nabla_grad(h_n1))) + elem_mult(dot(ANN.T,hat_psi_n1),q)*H0*H0/h_n1

# Total step-2 residual: h equation plus the n integrated interior components.
WF2 = WF_h + sum((WF_hat_psi_n1[ind])*dx for ind in range(0,n))

# --- Step 3 : get psi_s(t^{n+1})
# Second half step for psi_s, with the same structure as WF_psi_s (step 1) but
# using the new h_n1 and hat_psi_n1 together with the midpoint psi_s_half. The
# unknown psi_s_n1 appears only in v*psi_s_n1, so the form is linear in it
# (solved with solve(WF3 == 0, psi_s_n1) in the time loop).
WF3 = (v*psi_s_n1 - v*psi_s_half + (dt/(4.0*H0))*(
        v*M11*nabla_grad(psi_s_half)**2
        -2*psi_s_half*dot(nabla_grad(psi_s_half),nabla_grad(v))*D11
        + S11*(psi_s_half**2)*( 2.0*dot(nabla_grad(h_n1),nabla_grad(v))/h_n1 - v*(nabla_grad(h_n1)**2)/(h_n1**2))
        - v*A11*(H0*psi_s_half/h_n1)**2
        +2.0*g*H0*v*(h_n1-H0)
        +2*v*dot(MN1,dot(nabla_grad(hat_psi_n1).T,nabla_grad(psi_s_half)))
        -2*dot(hat_psi_n1,D1N)*dot(nabla_grad(psi_s_half),nabla_grad(v))
        + 2*psi_s_half*dot(hat_psi_n1,SN1)*(2.0*dot(nabla_grad(h_n1),nabla_grad(v))/h_n1 - v*(nabla_grad(h_n1)/h_n1)**2 )
        -2*v*psi_s_half*dot(hat_psi_n1,AN1)*(H0/h_n1)**2
        -2*psi_s_half*dot(DN1,dot(nabla_grad(hat_psi_n1).T,nabla_grad(v)))
        # The MNN and DNN terms below were missing from this half step. They
        # mirror the corresponding hat_psi-quadratic terms of WF_psi_s (step 1).
        +v*inner(dot(nabla_grad(hat_psi_n1),MNN),nabla_grad(hat_psi_n1))
        -2*dot(dot(nabla_grad(v),nabla_grad(hat_psi_n1)), dot(DNN,hat_psi_n1))
        +dot(hat_psi_n1,dot(SNN,hat_psi_n1))*(2*dot(nabla_grad(h_n1),nabla_grad(v))/h_n1 - v*(nabla_grad(h_n1)/h_n1)**2)
        -v*dot(hat_psi_n1,dot(ANN,hat_psi_n1))*(H0/h_n1)**2))*dx



"""
********************************************
*        Solve the weak formulations       *
******************************************** """

while t<Tend:
    # Progress report: percentage of [0, Tend] already simulated. A single
    # parenthesised argument prints the same text under Python 2 and Python 3.
    print("%s %%" % (100*t/Tend))
    #_______________ Save data _______________#
    save_h << h_n0
    save_psi_s << psi_s_n0
    save_h_ex << h_exact
    save_psi_s_ex << psi_s_exact
    # save_hat_psi << hat_psi_n0   # disabled: hat_psi_n0 is not updated below
    # save_phi << phi              # disabled: phi is never filled in the visible code

    #______________ Update time ______________#
    t += dt

    #______ Solve the weak formulations ______#
    solve(WF1 == 0, w1)         # step 1: psi_s and hat_psi at t^{n+1/2}
    solve(WF2 == 0, w2)         # step 2: h and hat_psi at t^{n+1}
    solve(WF3 == 0, psi_s_n1)   # step 3: psi_s at t^{n+1}

    #__________ Update the solutions _________#
    # w2.split() gives sub-Functions that can be assigned to the old-level fields.
    h_out, hat_psi_out = w2.split()
    h_n0.assign(h_out)
    psi_s_n0.assign(psi_s_n1)
    # hat_psi_n0.assign(hat_psi_out)   # not necessary

    #__________ Update the exact solutions _________#
    h_expr.t = t
    psi_s_expr.t = t
    h_exact.interpolate(h_expr)
    psi_s_exact.interpolate(psi_s_expr)

# Each state is written at the top of the loop, before it is advanced. Without
# this final write, the state reached by the last step would never be saved.
save_h << h_n0
save_psi_s << psi_s_n0
save_h_ex << h_exact
save_psi_s_ex << psi_s_exact

# Elapsed wall-clock time of the whole run, in seconds.
print("%s" % (time.time()-star_time))

"""
----->
----->
-----> COMPUTE phi = hat_psi * varphi_tilde (dot product)

"""