
import time
import numpy as np 
from vertical_discr import *
from firedrake import *

"""
    *********************************************
    *    Definition of the space/time domain    *
    ********************************************* """

#______________________ Beach ______________________#
xb = 7.5                         # Start of the beach
sb = 0.2                         # Slope of the beach

#______________________ Basin ______________________#
g = 9.81                     # Gravitational constant
H0 = 1.0                # Depth at rest (flat bottom)
Hend = 0.5            # Depth at the end of the beach
Lx = xb +(H0-Hend)/sb                   # Length in x
Ly = 2.0      # Length in y (Ly == 0 selects (x,z) waves)
Lw = 4.0                     # End of the x-transform

#____________________ Wavemaker ____________________#
# numpy's pi/sqrt/tanh are used on purpose: the sqrt/tanh brought in by
# `from firedrake import *` are UFL operators, whereas these wave parameters
# must stay plain floats (they are passed to Expression and define the
# time interval of the simulation).
lamb = 2.0                               # Wavelength
k = 2*np.pi/lamb                        # Wave number
w = np.sqrt(g*k*np.tanh(k*H0))   # Wave frequency (linear dispersion)
Tw = 2*np.pi/w                          # Wave period
gamma = 0.03                         # Wave amplitude
t_stop = 5.0*Tw          # When to stop the wavemaker

#______________________ Time _______________________#
T0 = 0.0                               # Initial time
t = T0                            # Temporal variable
Tend = 15*Tw                             # Final time
dt = 0.002                                # Time-step
scheme ="SE"   # Time scheme: "SE" (Symplectic-Euler), otherwise Stormer-Verlet

"""
    ****************************************************
    *               Definition of the mesh             *
    **************************************************** """

#________________ Vertical discretization _________________#
n_z = 4                          # order of the vertical expansion
Nz = n_z + 1         # number of points in one vertical element

#_______________ Horizontal discretization ________________#
res = 0.05                                # horizontal resolution
# Number of elements in x and in y: floor(length/res) + 1 each.
Nx, Ny = (int(length/res) + 1 for length in (Lx, Ly))

#__________________________ Mesh __________________________#
# Ly == 0 means (x,z)-waves on a 1D interval; otherwise (x,y,z)-waves on a
# quadrilateral rectangle mesh.
hor_mesh = (IntervalMesh(Nx, Lx) if Ly == 0
            else RectangleMesh(Nx, Ny, Lx, Ly, quadrilateral=True))

"""
    ************************************************
    *       Definition of the function spaces      *
    ************************************************ """
# Continuous degree-1 scalar space on the horizontal mesh: holds h and psi_1.
V = FunctionSpace(hor_mesh, "CG", 1)
# The same discretisation with n_z components per node: holds hat_psi,
# the coefficients of the velocity potential in depth.
Vec = VectorFunctionSpace(hor_mesh, "CG", 1, dim=n_z)
# Product space V x Vec, so that h and hat_psi can be solved for together.
V_mixed = V * Vec


"""
    ******************************************************
    *            Definition of the functions             *
    ****************************************************** """

#______________ Unknowns common to both time schemes ______________#
h_n0 = Function(V)                                   # h^n
psi_1_n0 = Function(V)                           # psi_1^n
hat_psi_n0 = Function(Vec)                     # hat_psi^n
psi_1_n1 = Function(V)                       # psi_1^{n+1}
w_n1 = Function(V_mixed)       # mixed unknown at the end of the step

#__________ Scheme-specific views of the mixed unknowns ___________#
if scheme == "SE":
    # Symplectic-Euler: w_n1 = (h^{n+1}, hat_psi^*); hat_psi^{n+1} is separate.
    h_n1, hat_psi_star = split(w_n1)
    hat_psi_n1 = Function(Vec)                 # hat_psi^{n+1}
else:
    # Stormer-Verlet: w_half = (h^{n+1/2}, hat_psi^*),
    #                 w_n1   = (h^{n+1},   hat_psi^{n+1}).
    w_half = Function(V_mixed)
    h_half, hat_psi_star = split(w_half)
    h_n1, hat_psi_n1 = split(w_n1)

# x-coordinate as a field in V (it appears in the transformed weak forms).
x_coord = Function(V).interpolate(Expression('x[0]'))

beach = Function(V)                                     # b(x)
H = Function(V)                          # H(x), depth at rest

#_________________________ Wavemaker ________________________#
WM = Function(V)                                  # R(x,y;t^n)
dWM_dt = Function(V)                               # (dR/dt)^n
dWM_dy = Function(V)                               # (dR/dy)^n
WM_n1 = Function(V)                  # R at the next time level
# TODO: WM must be evaluated differently depending on the time scheme.

#______________________ Trial functions ______________________#
psi_1 = TrialFunction(V)       # psi_1^{n+1} for linear solvers
hat_psi = TrialFunction(Vec)   # hat_psi^{n+1} for linear solvers

#_______________________ Test functions ______________________#
delta_h = TestFunction(V)                          # from dH/dh
delta_hat_psi = TestFunction(Vec)            # from dH/dhat_psi
w_t = TestFunction(V_mixed)     # from dH/dpsi_1 and dH/dhat_psi
delta_psi, delta_hat_star = split(w_t)


"""
    *************************************************************************
    *                     Initialisation of the Functions                   *
    *************************************************************************"""
#______________________________ Beach topography ______________________________#
beach_expr = \
Expression("0.5*(1+copysign(1.0,x[0]-xb))*slope*(x[0]-xb)",xb=xb, slope=sb)
beach.interpolate(beach_expr)                                             # b(x)

#________________________________ Depth at rest _______________________________#
# H(x) = H0 - b(x). Every parameter that appears in the expression string
# (H0, xb, slope) has to be supplied as a keyword argument.
H.interpolate(Expression("H0-0.5*(1+copysign(1.0,x[0]-xb))*slope*(x[0]-xb)",
                         H0=H0, xb=xb, slope=sb))                         # H(x)

#_____________________________ Problem dimension ______________________________#
# Follows the mesh choice: Ly == 0 builds an interval mesh for (x,z) waves,
# otherwise a rectangle mesh is used for (x,y,z) waves.
dim = 2 if Ly == 0 else 3

#______________________ Wavemaker motion and derivatives ______________________#
# WM is evaluated at t and WM_n1 at t+dt in both dimensions (WM_n1 enters the
# psi_1 mass matrix, so it must not be left at zero).
if dim==2:
    # Wavemaker displacement uniform in y.
    WM_code = "-0.5*(1+copysign(1.0,Lw-x[0]))*A*cos(w*t)"
    WM_expr = Expression(WM_code, A=gamma, Lw=Lw, w=w, t=t)
    WM.interpolate(WM_expr)                                   # \tilde{R}(x,y;t)

    WM_n1_expr = Expression(WM_code, A=gamma, Lw=Lw, w=w, t=t+dt)
    WM_n1.interpolate(WM_n1_expr)                          # \tilde{R}(x,y;t+dt)

    dWM_dt_expr = \
    Expression("0.5*(1+copysign(1.0,Lw-x[0]))*A*w*sin(w*t)",
               A=gamma, Lw=Lw, w=w, t=t)
    dWM_dt.interpolate(dWM_dt_expr)                              # d\tilde{R}/dt

    dWM_dy_expr = Expression("0.0")
    dWM_dy.interpolate(dWM_dy_expr)                          # d\tilde{R}/dy = 0
else:
    # Wavemaker displacement varying linearly in y about the mid-line y = Ly/2.
    WM_code = "-0.5*(1+copysign(1.0,Lw-x[0]))*A*(x[1]-0.5*Ly)/(0.5*Ly)*cos(w*t)"
    WM_expr = Expression(WM_code, A=gamma, Ly=Ly, Lw=Lw, w=w, t=t)
    WM.interpolate(WM_expr)                                   # \tilde{R}(x,y;t)

    WM_n1_expr = Expression(WM_code, A=gamma, Ly=Ly, Lw=Lw, w=w, t=t+dt)
    WM_n1.interpolate(WM_n1_expr)                          # \tilde{R}(x,y;t+dt)

    dWM_dt_expr = \
    Expression("0.5*(1+copysign(1.0,Lw-x[0]))*A*w*(x[1]-0.5*Ly)/(0.5*Ly)*sin(w*t)",
               A=gamma, Ly=Ly, Lw=Lw, w=w, t=t)
    dWM_dt.interpolate(dWM_dt_expr)                              # d\tilde{R}/dt

    dWM_dy_expr = Expression("-0.5*(1+copysign(1.0,Lw-x[0]))*A*cos(w*t)/(0.5*Ly)",
               A=gamma, Ly=Ly, Lw=Lw, w=w, t=t)
    dWM_dy.interpolate(dWM_dy_expr)                              # d\tilde{R}/dy

#____________________________________ Depth ___________________________________#
h_n0.assign(H)                                               # h(x,y;t=0) = H(x)

#________________ Velocity pot. at the surface: phi(x,y,z=h;t) ________________#
psi_1_n0.assign(0.0)                                       # \psi_1(x,y;t=0) = 0

#________________ Velocity pot. in depth: phi(x,y,z<h;t=0) = 0 ________________#
# One zero per component of Vec (n_z of them) instead of a hard-coded count.
hat_psi_n0.interpolate(Expression(("0.0",)*n_z))

#___________________________ Initial mixed function ___________________________#
# First component: h = H0 - b(x); the remaining n_z components (hat_psi) are 0.
# The number of zeros follows n_z so the Expression matches V_mixed = V * Vec.
w_init_expr = Expression(
    ("H0 - 0.5*(1+copysign(1.0,x[0]-xb))*slope*(x[0]-xb)",) + ("0.0",)*n_z,
    H0=H0, xb=xb, slope=sb)
if scheme=="SE": #__________________ Symplectic-Euler scheme __________________#
    w_n1.interpolate(w_init_expr)
else: #_________________________ Stormer-Verlet scheme ________________________#
    w_half.interpolate(w_init_expr)


"""
    ************************
    * Compute the matrices *
    ************************ """
#_______ Initialization ______#
# Zero-filled float arrays (np.zeros instead of a scaled identity matrix).
A = np.zeros((Nz, Nz))
M = np.zeros((Nz, Nz))
D = np.zeros((Nz, Nz))
S = np.zeros((Nz, Nz))
Ik = np.zeros((Nz, 1))

#____ Filling the matrices ___#
# Entries come from A_ij, M_ij, D_ij, S_ij and I_i (presumably provided by
# vertical_discr through the star import -- verify there for their meaning).
for i in range(Nz):
    for j in range(Nz):
        A[i, j] = A_ij(i, j, n_z, H0)
        M[i, j] = M_ij(i, j, n_z, H0)
        D[i, j] = D_ij(i, j, n_z, H0)
        S[i, j] = S_ij(i, j, n_z, H0)
    Ik[i] = I_i(i, n_z, H0)


#________ Submatrices ________#
def _split_blocks(mat):
    """Split an (Nz, Nz) matrix into its psi_1 / hat_psi blocks.

    Index 0 pairs with the surface unknown psi_1 and indices 1..n_z with
    hat_psi. Returns (mat[0,0], mat[0,1:], mat[1:,0], mat[1:,1:]), with the
    three non-scalar blocks wrapped as UFL tensors.
    """
    return (mat[0, 0], as_tensor(mat[0, 1:]),
            as_tensor(mat[1:, 0]), as_tensor(mat[1:, 1:]))


A11, A1N, AN1, ANN = _split_blocks(A)
M11, M1N, MN1, MNN = _split_blocks(M)
D11, D1N, DN1, DNN = _split_blocks(D)
S11, S1N, SN1, SNN = _split_blocks(S)

I1 = Ik[0, 0]
IN = as_tensor(Ik[1:, 0])


#________ Linear system for hat_psi^{n+1}, given h^{n+1} and psi_1^{n+1} ________#
# a_hat_psi: vector-valued bilinear integrand in the trial function hat_psi,
# tested with delta_hat_psi and built from the NN blocks (MNN, DNN, SNN, ANN).
# The geometric factors in Lw, WM and (x_coord-Lw)*dWM_dy presumably come from
# the x-transform over [0, Lw] that follows the wavemaker -- confirm against
# the derivation before editing any term.
a_hat_psi =((h_n1/(Lw-WM))*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*elem_mult(delta_hat_psi.dx(0), dot(MNN,hat_psi.dx(0)))\
            +(Lw-WM)*h_n1*elem_mult(delta_hat_psi.dx(1),dot(MNN,hat_psi.dx(1)))\
            +(x_coord-Lw)*dWM_dy*h_n1*(elem_mult(dot(hat_psi.dx(0),MNN),delta_hat_psi.dx(1))\
                                       +elem_mult(delta_hat_psi.dx(0),dot(MNN,hat_psi.dx(1))))\
            -( (Lw*Lw+((x_coord-Lw)*(dWM_dy))**2)*h_n1.dx(0)/(Lw-WM) \
              +(x_coord-Lw)*dWM_dy*h_n1.dx(1))*(elem_mult(delta_hat_psi, dot(DNN.T,hat_psi.dx(0))) \
                                              + elem_mult(delta_hat_psi.dx(0),dot(DNN,hat_psi))) \
            -((Lw-WM)*h_n1.dx(1) + (x_coord-Lw)*dWM_dy*h_n1.dx(0))*(elem_mult(delta_hat_psi, dot(DNN.T,hat_psi.dx(1)))\
                                                                   +elem_mult(delta_hat_psi.dx(1),dot(DNN,hat_psi)))\
            +(1.0/h_n1)*( (Lw*Lw+((x_coord-Lw)*(dWM_dy))**2)*(h_n1.dx(0)**2)/(Lw-WM) +(Lw-WM)*(h_n1.dx(1)**2) \
                         +2.0*h_n1.dx(0)*h_n1.dx(1)*(x_coord-Lw)*dWM_dy)*elem_mult(delta_hat_psi,dot(SNN,hat_psi))\
            + ((Lw-WM)*H0*H0/h_n1)*elem_mult(delta_hat_psi,dot(ANN,hat_psi)))


# L_hat_psi: the same operator applied to the known psi_1^{n+1} through the
# coupling blocks (MN1, D1N, DN1, SN1, AN1), with the sign flipped so that it
# serves as the right-hand side.
L_hat_psi =-((h_n1/(Lw-WM))*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*elem_mult(delta_hat_psi.dx(0), MN1*psi_1_n1.dx(0))\
              +(Lw-WM)*h_n1*elem_mult(delta_hat_psi.dx(1), MN1*psi_1_n1.dx(1))\
              +(x_coord-Lw)*dWM_dy*h_n1*(elem_mult(delta_hat_psi.dx(1), MN1*psi_1_n1.dx(0))\
                                         +elem_mult(delta_hat_psi.dx(0),MN1*psi_1_n1.dx(1)))\
              -( (Lw*Lw+((x_coord-Lw)*(dWM_dy))**2)*h_n1.dx(0)/(Lw-WM) \
                +(x_coord-Lw)*dWM_dy*h_n1.dx(1))*(elem_mult(delta_hat_psi, D1N*psi_1_n1.dx(0)) \
                                                  + elem_mult(delta_hat_psi.dx(0),DN1*psi_1_n1)) \
              -((Lw-WM)*h_n1.dx(1) + (x_coord-Lw)*dWM_dy*h_n1.dx(0))*(elem_mult(delta_hat_psi, D1N*psi_1_n1.dx(1))\
                                                                      +elem_mult(delta_hat_psi.dx(1),DN1*psi_1_n1))\
              +(1.0/h_n1)*( (Lw*Lw+((x_coord-Lw)*(dWM_dy))**2)*(h_n1.dx(0)**2)/(Lw-WM) +(Lw-WM)*(h_n1.dx(1)**2) \
                           +2.0*h_n1.dx(0)*h_n1.dx(1)*(x_coord-Lw)*dWM_dy)*elem_mult(delta_hat_psi,SN1*psi_1_n1)\
              + ((Lw-WM)*H0*H0/h_n1)*elem_mult(delta_hat_psi,AN1*psi_1_n1))

# Wavemaker term (dWM_dt weighted by the vector IN), moved to the right-hand
# side and integrated over boundary marker 1 (presumably the wavemaker wall x=0).
hat_psi_BC = -(Lw*dWM_dt*h_n1*elem_mult(delta_hat_psi,IN))

# elem_mult yields n_z-component integrands; their components are summed to
# obtain the scalar forms A_hat (bilinear) and L_hat (linear).
A_hat = sum((a_hat_psi[ind])*dx for ind in range(0,n_z))
L_hat = sum((L_hat_psi[ind])*dx for ind in range(0,n_z)) + sum((hat_psi_BC[ind])*ds(1) for ind in range(0,n_z))

# delta psi_1:
# Nonlinear residual tested with delta_psi: the discrete time derivative of h
# minus the psi_1-variation (dH/dpsi_1) of the energy, evaluated at h^{n+1},
# psi_1^n and hat_psi^*; the last term is the wavemaker flux on ds(1).
# The transverse kinetic term (Lw-WM)*h*(d psi/dy ...) is tested with
# delta_psi.dx(1), mirroring the x-term's delta_psi.dx(0) and the y-terms of
# the hat_psi equations.
WF_h = (delta_psi*(h_n1-h_n0)*H0*(Lw-WM)/dt \
        -((h_n1/(Lw-WM))*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*(psi_1_n0.dx(0)*M11 \
                                                           + dot(hat_psi_star.dx(0),MN1))*delta_psi.dx(0)\
             +(Lw-WM)*h_n1*(psi_1_n0.dx(1)*M11+dot(hat_psi_star.dx(1),MN1))*delta_psi.dx(1) \
             +(x_coord-Lw)*dWM_dy*h_n1*(delta_psi.dx(0)*(M11*psi_1_n0.dx(1) + dot(M1N,hat_psi_star.dx(1))) \
                                        + delta_psi.dx(1)*(M11*psi_1_n0.dx(0) + dot(MN1, hat_psi_star.dx(0))))\
             -( (1/(Lw-WM))*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*h_n1.dx(0) \
               +(x_coord-Lw)*dWM_dy*h_n1.dx(1))*( delta_psi.dx(0)*(D11*psi_1_n0 + dot(D1N,hat_psi_star)) \
                                                 +delta_psi*(psi_1_n0.dx(0)*D11 + dot(hat_psi_star.dx(0),DN1)))\
             -( (Lw-WM)*h_n1.dx(1) + (x_coord-Lw)*dWM_dy*h_n1.dx(0))*( delta_psi.dx(1)*(D11*psi_1_n0 + dot(D1N,hat_psi_star))\
                                                                +delta_psi*(psi_1_n0.dx(1)*D11 + dot(hat_psi_star.dx(1),DN1)))\
             +(1/h_n1)*( (1/(Lw-WM))*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*(h_n1.dx(0)**2) +(Lw-WM)*(h_n1.dx(1)**2) \
                        + 2.0*h_n1.dx(0)*h_n1.dx(1)*(x_coord-Lw)*dWM_dy)*(psi_1_n0*S11 + dot(hat_psi_star,SN1))*delta_psi\
             +((Lw-WM)*H0*H0/h_n1)*(psi_1_n0*A11 + dot(hat_psi_star,AN1))*delta_psi \
             -delta_psi*H0*(x_coord-Lw)*dWM_dt*h_n1.dx(0)))*dx \
- (delta_psi*Lw*dWM_dt*h_n1*I1)*ds(1)

# delta hat_psi (tested with delta_hat_star):
# hat_psi-variation (dH/dhat_psi) of the same energy, evaluated at h^{n+1},
# psi_1^n and hat_psi^*; it is set to zero together with WF_h.
WF_hat_psi_star= ((h_n1/(Lw-WM))*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*elem_mult(delta_hat_star.dx(0),(MN1*psi_1_n0.dx(0)\
                                                                                                  +dot(MNN,hat_psi_star.dx(0))))\
                  +(Lw-WM)*h_n1*elem_mult(delta_hat_star.dx(1),(MN1*psi_1_n0.dx(1)+dot(MNN,hat_psi_star.dx(1))))\
                  +(x_coord-Lw)*dWM_dy*h_n1*(elem_mult((dot(hat_psi_star.dx(0),MNN)+psi_1_n0.dx(0)*M1N),delta_hat_star.dx(1))\
                                            +elem_mult(delta_hat_star.dx(0),(MN1*psi_1_n0.dx(1)+ dot(MNN,hat_psi_star.dx(1)))))\
                  -( (Lw*Lw+((x_coord-Lw)*(dWM_dy))**2)*h_n1.dx(0)/(Lw-WM) \
                    +(x_coord-Lw)*dWM_dy*h_n1.dx(1))*(elem_mult(delta_hat_star, (psi_1_n0.dx(0)*D1N+ dot(DNN.T,hat_psi_star.dx(0)))) \
                                                    + elem_mult(delta_hat_star.dx(0),(DN1*psi_1_n0+dot(DNN,hat_psi_star)))) \
                  -((Lw-WM)*h_n1.dx(1) + (x_coord-Lw)*dWM_dy*h_n1.dx(0))*(elem_mult(delta_hat_star, (D1N*psi_1_n0.dx(1)\
                                                                                                +dot(DNN.T,hat_psi_star.dx(1))))\
                                                                          +elem_mult(delta_hat_star.dx(1),(DN1*psi_1_n0 \
                                                                                                     + dot(DNN,hat_psi_star))))\
                  +(1.0/h_n1)*( (Lw*Lw+((x_coord-Lw)*(dWM_dy))**2)*(h_n1.dx(0)**2)/(Lw-WM) +(Lw-WM)*(h_n1.dx(1)**2) \
                               +2.0*h_n1.dx(0)*h_n1.dx(1)*(x_coord-Lw)*dWM_dy)*elem_mult(delta_hat_star,(SN1*psi_1_n0\
                                                                                                   + dot(SNN,hat_psi_star)))\
                  + ((Lw-WM)*H0*H0/h_n1)*elem_mult(delta_hat_star,(AN1*psi_1_n0+dot(ANN,hat_psi_star))))

# Wavemaker contribution to the hat_psi-variation, integrated over ds(1).
WF_hat_BC = (Lw*dWM_dt*h_n1*elem_mult(delta_hat_star,IN))

# Full residual for the mixed unknown (used with w_n1 in the SE solver): the
# scalar WF_h plus the n_z components of the hat_psi residual summed into forms.
WF_h_psi = WF_h + sum((WF_hat_psi_star[ind])*dx for ind in range(0,n_z)) + sum((WF_hat_BC[ind])*ds(1) for ind in range(0,n_z))


# delta h: linear problem for psi_1^{n+1}, tested with delta_h.
# A_psi_s: mass term H0*(Lw-WM_n1)*psi_1, where WM_n1 is meant to hold the
# wavemaker position at the new time level.
A_psi_s = (H0*delta_h*(Lw-WM_n1)*psi_1)*dx

# L_psi_s: old mass term H0*(Lw-WM)*psi_1^n minus dt times the h-variation
# (dH/dh) of the energy. This covers the kinetic M, D, S and A terms, the
# gravity term g*(h^{n+1}-H) and the wavemaker term in dWM_dt. The final ds(1)
# integral is the boundary part of the same variation, also scaled by -dt.
L_psi_s = -(-H0*delta_h*(Lw-WM)*psi_1_n0 \
            +dt*(delta_h*( ((Lw*Lw+((x_coord-Lw)*dWM_dy)**2)/(2.0*(Lw-WM)))*((psi_1_n0.dx(0)**2)*M11 \
                                                                             +dot(hat_psi_star.dx(0), (2.0*MN1*psi_1_n0.dx(0)\
                                                                                                       +dot(MNN,hat_psi_star.dx(0)))))\
                          +0.5*(Lw-WM)*( (psi_1_n0.dx(1)**2)*M11 + dot(hat_psi_star.dx(1), (2.0*MN1*psi_1_n0.dx(1) + dot(MNN,hat_psi_star.dx(1)))))\
                          +(x_coord-Lw)*dWM_dy*( psi_1_n0.dx(0)*(M11*psi_1_n0.dx(1) + dot(MN1,hat_psi_star.dx(1))) \
                                                + dot(hat_psi_star.dx(0), (MN1*psi_1_n0.dx(1) + dot(MNN,hat_psi_star.dx(1))))))\
                 -((1.0/(Lw-WM))*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*delta_h.dx(0) \
                   + (x_coord-Lw)*dWM_dy*delta_h.dx(1))*( psi_1_n0.dx(0)*(D11*psi_1_n0 + dot(D1N,hat_psi_star)) \
                                                         +dot(hat_psi_star.dx(0), (DN1*psi_1_n0 + dot(DNN, hat_psi_star))))\
                 -((Lw-WM)*delta_h.dx(1) + (x_coord-Lw)*delta_h.dx(0)*dWM_dy)*( psi_1_n0.dx(1)*(D11*psi_1_n0 + dot(D1N,hat_psi_star))\
                                                                               +dot(hat_psi_star.dx(1), (DN1*psi_1_n0 + dot(DNN, hat_psi_star))))\
                 +(1.0/h_n1)*(delta_h.dx(0)*((1.0/(Lw-WM))*h_n1.dx(0)*(Lw*Lw+((x_coord-Lw)*dWM_dy)**2) + h_n1.dx(1)*(x_coord-Lw)*dWM_dy)\
                              -(delta_h/h_n1)*( (Lw*Lw+((x_coord-Lw)*dWM_dy)**2)*(h_n1.dx(0)**2)/(2.0*(Lw-WM)) + 0.5*(Lw-WM)*(h_n1.dx(1)**2)\
                                               + h_n1.dx(0)*h_n1.dx(1)*(x_coord-Lw)*dWM_dy )\
                              + delta_h.dx(1)*( (Lw-WM)*h_n1.dx(1) + h_n1.dx(0)*(x_coord-Lw)*dWM_dy))*(psi_1_n0*psi_1_n0*S11 \
                                                                                                       + 2.0*dot(hat_psi_star,SN1)*psi_1_n0\
                                                                                                       +dot(hat_psi_star,dot(SNN,hat_psi_star)))\
                 -(0.5*delta_h*(Lw-WM)*H0*H0/(h_n1**2))*(psi_1_n0*psi_1_n0*A11 + 2.0*dot(hat_psi_star,AN1)*psi_1_n0 \
                                                         + dot(hat_psi_star,dot(ANN,hat_psi_star)))\
                 +H0*g*(Lw-WM)*delta_h*(h_n1-H) - H0*psi_1_n0*(x_coord-Lw)*dWM_dt*delta_h.dx(0)))*dx - dt*(Lw*dWM_dt*delta_h*(psi_1_n0*I1 + dot(hat_psi_star,IN)))*ds(1)
                   


"""
    ********************************************
    *             Define the solvers            *
    ******************************************** """
#___ Solver parameters ___#
# The solvers below use Firedrake's default options. PETSc options can be
# passed through solver_parameters=..., e.g. for the hat_psi solve:
#   {"ksp_converged_reason": True, "pc_type": "fieldsplit",
#    "pc_fieldsplit_type": "schur", "pc_fieldsplit_schur_fact_type": "FULL"}

if scheme=="SE":
    #___ Variational solver for h (and hat_psi^*) ___#
    h_problem = NonlinearVariationalProblem(WF_h_psi, w_n1)
    h_solver = NonlinearVariationalSolver(h_problem)

    #___ Variational solver for psi_1 ___#
    psi_problem = LinearVariationalProblem(A_psi_s, L_psi_s, psi_1_n1)
    psi_solver = LinearVariationalSolver(psi_problem)

    #___ Variational solver for hat_psi ___#
    hat_psi_problem = LinearVariationalProblem(A_hat, L_hat, hat_psi_n1)
    hat_psi_solver = LinearVariationalSolver(hat_psi_problem)

    # Only the (h^{n+1}, hat_psi^*) step is run here; psi_solver and
    # hat_psi_solver are set up but not yet called.
    h_solver.solve()
else:
    # No solvers are defined for any other scheme, so fail with a clear
    # message instead of a NameError on h_solver.
    raise NotImplementedError(
        "Solvers are only defined for the Symplectic-Euler scheme ('SE'); "
        "got scheme=%r" % (scheme,))
