# Solve Nonlinear Benney-Luke equations

from firedrake import *

# --- Physical and numerical parameters ---------------------------------------
# Domain extent.  Both must be integers because the grid below uses Lx*8 by
# Ly*8 cells; suitable values depend on mu.
Lx, Ly = 1, 20

# Time stepping: final time, step size, and the running clock.
T, dt = 20.0, 0.001
t = 0

# Dispersion parameter mu and nonlinearity parameter epsilon
# (epsilon = 0 gives the linear problem).
mu, epsilon = 0.04, 0.25

# Amplitude and mode numbers of the cosine profile used for eta0.
Ampl = 0.1
kx, ky = 2, 2

# Squared wavenumber of the (kx, ky) mode and the associated frequency.
# NOTE(review): omega is not referenced elsewhere in this file.
k2 = (kx*pi/Lx)**2 + (ky*pi/Ly)**2
omega = sqrt(k2*(1+2*mu/3.0*k2))/(1+0.5*mu*k2)

# --- Mesh -----------------------------------------------------------------------
# Eight cells per unit length in each direction.
Nx, Ny = 8*Lx, 8*Ly

# Quadrilateral mesh of the unit square: an interval of Nx cells extruded
# through Ny layers.  For a triangular mesh use instead:
#     mesh = UnitSquareMesh(Nx, Ny)
m = UnitIntervalMesh(Nx)
mesh = ExtrudedMesh(m, layers=Ny)

# Rescale the unit square to [0, Lx] x [0, Ly], one coordinate column at a time.
coords = mesh.coordinates
for axis, length in ((0, Lx), (1, Ly)):
    coords.dat.data[:, axis] = length*coords.dat.data[:, axis]

# --- Stretch mesh -----------------------------------------------------------------
# Shear the mesh vertically.  A vertex in layer j is lowered by
# j*slope*|x - 0.5*Lx|/Ny.  The bottom row stays fixed and the centre line
# x = 0.5*Lx is untouched, while the top row drops by up to L at the side
# walls.  Since L = sqrt(d**2 - (0.5*Lx)**2), d is the length of each slanted
# half of the resulting top boundary.
d = 1.2   # value depends on mu
L = sqrt(d**2 - (0.5*Lx)**2)
slope = L/(0.5*Lx)

# Fetch the coordinate array once.  The .dat.data property is not a plain
# attribute, so it is not re-read four times per vertex.
xy = coords.dat.data
n_vertices = len(xy)
# Size from the actual vertex count rather than assuming (Nx+1)*(Ny+1).
dy = [0]*n_vertices

for i in range(n_vertices):
    x = xy[i, 0]
    if x < 0.5*Lx:
        dy[i] = slope*(x-0.5*Lx)/Ny
    else:
        dy[i] = slope*(0.5*Lx-x)/Ny
    # NOTE(review): i % (Ny+1) is used as the layer index.  This assumes the
    # extruded mesh numbers the Ny+1 vertices of each column consecutively;
    # confirm for the Firedrake version in use (and note it fails in parallel).
    xy[i, 1] = xy[i, 1] + i%(Ny+1)*dy[i]

# --- Function space and fields -------------------------------------------------
# Every field lives in the continuous piecewise-quadratic Lagrange space.
V = FunctionSpace(mesh,"CG",2)

# Cosine-mode initial elevation of amplitude Ampl.
# NOTE: eta0 is later overwritten by eta0.assign(etaR).
eta0 = Function(V)
eta0.interpolate(Expression("Ampl*cos(kx*pi*x[0]/Lx)*cos(ky*pi*x[1]/Ly)", kx=kx, ky=ky, Lx=Lx, Ly=Ly, Ampl=Ampl))

# The remaining fields, all zero-initialised.  These are the unknowns at the
# various time levels, the auxiliary field Lphi1, and the reference level etaR.
phi0, eta1, phi1, eta2, phi2, Lphi1, etaR = (Function(V) for _ in range(7))

# Trial and test functions.  Lphi is not used by any form in this file.
Lphi = TrialFunction(V)
gamma = TestFunction(V)

# --- Reference level etaR ("gravitational potential") ---------------------------
# h1 and h0 are the levels 0.9 and 0.43 taken relative to H0s and divided by
# epsilon*H0s.
H0s = 0.45
h1, h0 = (0.9-H0s)/(epsilon*H0s), (0.43-H0s)/(epsilon*H0s)
g, LL = 9.81, 2.63
# y1 and y2 are the positions LL and LL+0.2 divided by the scale H0s/sqrt(mu).
y1 = LL/(H0s/sqrt(mu))
y2 = (LL+0.2)/(H0s/sqrt(mu))
# Ts is the duration of the linear ramp of etaR_expr.H0 (from h0 to h1)
# applied in the time loop.
Ts = (0.9/2.5)/sqrt(H0s/g/mu)

# Smoothed step profile in y.  tanh(a*s) with a large a stands in for sign(s),
# so that
#     etaR ~ (h1-H0)                       for y < y1,
#     etaR ~ (h1-H0)*(1-(y-y1)/(y2-y1))    for y1 < y < y2 (linear ramp to 0),
#     etaR ~ 0                             for y > y2.
a = 1000
etaR_expr = Expression("(h1-H0)*(0.5*(1+tanh(a*(y1-x[1]))) + 0.25*(1+tanh(a*(y2-x[1])))*(1+tanh(a*(x[1]-y1)))*(1-(x[1]-y1)/(y2-y1)))", y1=y1, y2=y2, h1=h1, H0=h0, a=a)
etaR.interpolate(etaR_expr)

# Write the initial etaR.  Further snapshots are appended in the time loop.
etaR_file = File("etaR.pvd")
etaR_file << etaR

# The simulation starts with eta equal to the reference level.
eta0.assign(etaR)

# --- Time-stepping schemes ----------------------------------------------------------
# order_time selects the scheme: 1 = first-order Euler-type splitting,
# 2 = Stormer-Verlet.
order_time = 1  # =1 for Euler, =2 for Stormer-Verlet

# PETSc options shared by every solver.  They print the SNES and KSP residual
# histories and the line-search steps.  Each solver gets its own copy
# (dict(...)), so any entries a solver adds to its parameter dict cannot leak
# into the others.
monitor_params = {'snes_monitor': True, 'ksp_monitor': True, 'snes_linesearch_monitor': True}

# Auxiliary field Lphi1, used by both schemes.  The weak form
#     int gamma*Lphi1 dx = int grad(gamma).grad(phi1) dx
# makes Lphi1 the L2 projection of -div(grad(phi1)), given homogeneous
# Neumann data on phi1 (no boundary term appears).  It is NOT grad^4(phi).
# The fourth-order contribution arises only where Lphi1 enters the eta
# equation through the 2/3*mu*grad(Lphi1) term.
FLaplace = ( gamma*Lphi1 - inner(grad(gamma),grad(phi1)) )*dx
Laplace_problem = NonlinearVariationalProblem(FLaplace,Lphi1)
Laplace_solver = NonlinearVariationalSolver(Laplace_problem,solver_parameters=dict(monitor_params))

if order_time == 1:
    # phi1 = phi^{n+1} from phi0 and eta0.  The nonlinear term uses phi1.
    # NOTE(review): the author left open whether that term should use phi0 or phi1.
    Fphi = ( gamma*(phi1-phi0)/dt
             + 0.5*mu*inner(grad(gamma),grad((phi1-phi0)/dt))
             + gamma*(eta0-etaR)
             - 0.5*epsilon*phi1*div(gamma*grad(phi1)) )*dx
    phi_problem = NonlinearVariationalProblem(Fphi,phi1)
    phi_solver = NonlinearVariationalSolver(phi_problem,solver_parameters=dict(monitor_params))

    # eta1 = eta^{n+1} from eta0, the new phi1, and Lphi1.
    Feta = ( gamma*(eta1-eta0)/dt
             + 0.5*mu*inner(grad(gamma),grad((eta1-eta0)/dt))
             - (1+epsilon*eta0)*inner(grad(gamma),grad(phi1))
             - 2.0/3.0*mu*inner(grad(gamma),grad(Lphi1)) )*dx
    eta_problem = NonlinearVariationalProblem(Feta,eta1)
    eta_solver = NonlinearVariationalSolver(eta_problem,solver_parameters=dict(monitor_params))

elif order_time == 2:
    # Half step, implicit: phi1 = phi^{n+1/2} from phi0 and eta0.
    Fphi10 = ( gamma*(phi1-phi0)/(0.5*dt)
               + 0.5*mu*inner(grad(gamma),grad((phi1-phi0)/(0.5*dt)))
               + gamma*(eta0-etaR)
               - 0.5*epsilon*phi1*div(gamma*grad(phi1)) )*dx
    phi_problem10 = NonlinearVariationalProblem(Fphi10,phi1)
    phi_solver10 = NonlinearVariationalSolver(phi_problem10,solver_parameters=dict(monitor_params))

    # Full step, implicit: eta2 = eta^{n+1}.  Dividing by 0.5*dt and summing
    # the flux terms at eta0 and eta2 means (eta2-eta0)/dt equals their average.
    Feta20 = ( gamma*(eta2-eta0)/(0.5*dt)
               + 0.5*mu*inner(grad(gamma),grad((eta2-eta0)/(0.5*dt)))
               - (1+epsilon*eta0)*inner(grad(gamma),grad(phi1))
               - 2.0/3.0*mu*inner(grad(gamma),grad(Lphi1))
               - (1+epsilon*eta2)*inner(grad(gamma),grad(phi1))
               - 2.0/3.0*mu*inner(grad(gamma),grad(Lphi1)) )*dx
    eta_problem20 = NonlinearVariationalProblem(Feta20,eta2)
    eta_solver20 = NonlinearVariationalSolver(eta_problem20,solver_parameters=dict(monitor_params))

    # Second half step, explicit (nonlinear term at phi1): phi2 = phi^{n+1}.
    Fphi21 = ( gamma*(phi2-phi1)/(0.5*dt)
               + 0.5*mu*inner(grad(gamma),grad((phi2-phi1)/(0.5*dt)))
               + gamma*(eta2-etaR)
               - 0.5*epsilon*phi1*div(gamma*grad(phi1)) )*dx
    phi_problem21 = NonlinearVariationalProblem(Fphi21,phi2)
    phi_solver21 = NonlinearVariationalSolver(phi_problem21,solver_parameters=dict(monitor_params))

else:
    # Previously an unknown value built no solvers, and the time loop then
    # silently advanced t without solving anything.
    raise ValueError("order_time must be 1 (Euler) or 2 (Stormer-Verlet), got %r" % (order_time,))

# --- Output ---------------------------------------------------------------------------
# ParaView time series for phi and eta, starting with the initial state.
# (Energy monitoring is done in the time loop, once t > Ts.)
phi_file, eta_file = File('phi.pvd'), File('eta.pvd')
phi_file << phi0
eta_file << eta0

# --- Time loop --------------------------------------------------------------------------
# Advance the solution from t = 0 to T.
#  * While t < Ts, the H0 parameter of etaR_expr ramps linearly from h0 to h1.
#  * Once t > Ts, the relative energy drift |E - E0|/E0 is written to
#    energy.txt, where E0 is the energy at the first step after Ts.
#
# The energy form references eta0 and phi1.  Both are updated in place every
# step, so the form is built once and simply re-assembled.  inner(g, g) is the
# explicit form of |grad(phi1)|^2; the old abs(grad(phi1))**2 relied on UFL
# accepting a power of a vector expression.
energy_form = ( 0.5*(1+epsilon*eta0)*inner(grad(phi1),grad(phi1))
                + 0.5*eta0**2
                + mu/3.0*div(grad(phi1))**2 )*dx
E0 = None

# The with-block guarantees energy.txt is closed.  The old code raised
# NameError at f.close() if t never exceeded Ts, and leaked the handle on errors.
with open('energy.txt', 'w') as energy_log:
    while t < T-0.5*dt:
        t += dt
        # Prints the same text as Python 2's "print t, value", and also
        # runs under Python 3.
        print("%s %s" % (t, assemble(phi0*phi0*dx)))

        if t < Ts:
            etaR_expr.H0 = h1+(h0-h1)*(Ts-t)/Ts
        else:
            etaR_expr.H0 = h1
        etaR.interpolate(etaR_expr)

        if order_time == 1:
            phi_solver.solve()      # phi1 = phi^{n+1}
            Laplace_solver.solve()  # Lphi1 from phi1
            eta_solver.solve()      # eta1 = eta^{n+1}

            eta0.assign(eta1)
            phi0.assign(phi1)
        elif order_time == 2:
            phi_solver10.solve()    # phi1 = phi^{n+1/2}
            Laplace_solver.solve()  # Lphi1 from phi^{n+1/2}
            eta_solver20.solve()    # eta2 = eta^{n+1}
            phi_solver21.solve()    # phi2 = phi^{n+1}

            # Shift to the new time level.  phi0 must become phi^{n+1} (phi2).
            # Copying phi1 here would start the next step from phi^{n+1/2}.
            eta0.assign(eta2)
            phi0.assign(phi2)
            phi1.assign(phi2)       # phi1 is what gets written out and monitored
        else:
            raise ValueError("order_time must be 1 (Euler) or 2 (Stormer-Verlet), got %r" % (order_time,))

        phi_file << phi1
        eta_file << eta0
        etaR_file << etaR

        # Monitor the energy once the etaR ramp has finished.
        if t > Ts:
            E = assemble(energy_form)
            if E0 is None:
                E0 = E
            energy_log.write('%-10s %s\n' % (t, abs(E-E0)/E0))

