from firedrake import *



# Mesh resolution (same number of cells in x and z)
Nx = 16
Nz = Nx
mesh = UnitSquareMesh(Nx, Nz)

# Reference RT field with every degree of freedom set to one.  Its sign
# against the facet normal fixes a consistent flux orientation on each
# interior facet, so the upwinding in div_u agrees in parallel runs.
fs = FunctionSpace(mesh, "RT", 1)
f = Function(fs)
f.dat.data[:] = 1.0

# Index of the vertical (gravity) direction
z_index = mesh.geometric_dimension() - 1

# Timestep, tied to the mesh resolution
dt = 1.0 / Nx**2

# Start and end time of the simulation.
# The wave frequency is sigma (defined below); the period is 2*pi/sigma.
t = 0.0
end = 1.0

# Polynomial degree of the DG basis (degree 0 would give a
# finite-volume scheme) and the quadrature degree used for projections.
order_basis = 2
quad_degree = 4 * order_basis

# Weight used in the theta-averaged numerical flux
theta = Constant(0.6)

# Background density r_0(z) = exp(-3 z) and its vertical derivative.
Ro = FunctionSpace(mesh, "DG", order_basis)

r_0_project = Function(Ro)
r_0_project.interpolate(Expression("exp(-3.0*x[1])"))

dr_0 = Function(Ro)
dr_0.interpolate(Expression("-3.0*exp(-3.0*x[1])"))

# Buoyancy frequency squared, derived analytically for this background.
# NOTE: a non-constant N_sq (e.g. a Function(Ro)) has not worked in the
# other terms such as dH/dp; a workaround is still needed.
N_sq = Constant(2.0)

Ma = Constant(1.0)    # Mach number
c_sq = Constant(1.0)  # speed of sound squared
g = Constant(1.0)     # gravity

# Wavenumber and wave frequency of the exact solution
m = 2 * pi
_sigma_a = 8 * pi**2 + 0.25 * (N_sq + 1)**2
sigma = (0.5 * (_sigma_a + (_sigma_a**2 - 16 * N_sq * pi**2)**0.5))**0.5




# Declare function spaces on the mesh and create the mixed function space
# for the coupled Poisson bracket.  The six components are
# (u, rho, p, dH/du, dH/drho, dH/dp).

V = VectorFunctionSpace(mesh, "DG", order_basis)
R = FunctionSpace(mesh, "DG", order_basis)

W = V*R*R*V*R*R


# Initial conditions

# Mixed state holding all six fields
w0 = Function(W)

# Interpolate the exact solution, evaluated at phase t = 0.125, into the
# momentum, density and pressure components.
# NOTE: the vertical structure of every field must be a linear combination
# of cos(m z) and sin(m z).  In the x-momentum the sin(m z) factor belongs
# only to the 0.5*(N_sq-1) term.  With this grouping the continuity
# relation d(rho)/dt = -div(u) reproduces rho0 below exactly; grouping
# both terms under sin(m z) would introduce a spurious sin(2 m z) mode.
u0, rho0, p0, dHdu0, dHdrho0, dHdp0 = w0.split()

u0.interpolate(Expression(["exp( -0.5*(N_sq+1)*x[1] )* sin ( m*x[0] )*(2*pi/ ( 4*pow(pi,2) - pow(sigma,2) ) )*( -2*pi*cos(m*x[1]) - 0.5 *(N_sq-1)*sin( m*x[1]) ) *sin(sigma*0.125)","exp( -0.5*(N_sq+1)*x[1] )*sin(m*x[1])*cos(m*x[0])*sin(sigma*0.125)"], N_sq=N_sq, m=m , sigma=sigma ))

rho0.interpolate(Expression("exp( -0.5*(N_sq+1)*x[1] )*( sigma / ( 4*pow(pi,2)- pow(sigma,2)) )*cos( m*x[0] )*( (0.5*(N_sq + 1) - (4*N_sq*pow(pi,2))/pow(sigma,2))*sin( m*x[1]) - 2*pi*cos( m*x[1] ) )*cos( sigma*0.125 )", N_sq=N_sq, m=m , sigma=sigma ))

p0.interpolate(Expression("exp( -0.5*(N_sq+1)*x[1] )*( sigma / ( 4*pow(pi,2) - pow(sigma,2)) )*cos( m*x[0] )*( -0.5*(N_sq - 1)*sin( m*x[1]) - 2*pi*cos( m*x[1] ) )*cos( sigma*0.125 )", N_sq=N_sq, m=m , sigma=sigma ))

# Project the initial fields onto the variational derivatives of the
# Hamiltonian
#   H = int( 0.5/r_0 |u|^2 + g^2/(2 r_0 N_sq) (rho - Ma p/c_sq)^2
#            + Ma/(2 r_0 c_sq) p^2 ) dx,
# which is the same energy assembled as E0 below.

q = TrialFunction(V)
q_test = TestFunction(V)

# dH/du = u / r_0
a_u_project = dot(q, q_test)*dx(degree=quad_degree)
L_u_project = dot(u0/r_0_project, q_test)*dx(degree=quad_degree)

solve(a_u_project == L_u_project, dHdu0, solver_parameters={'ksp_rtol': 1e-14})

b = TrialFunction(R)
b_test = TestFunction(R)

# dH/drho = g^2/(r_0 N_sq) (rho - Ma p/c_sq)
a_rho_project = b*b_test*dx(degree=quad_degree)
L_rho_project = g**2/(r_0_project*N_sq)*(rho0 - Ma*p0/c_sq)*b_test*dx(degree=quad_degree)

solve(a_rho_project == L_rho_project, dHdrho0, solver_parameters={'ksp_rtol': 1e-14})

# dH/dp = g^2/(r_0 N_sq) (Ma^2 p/c_sq^2 - Ma rho/c_sq) + Ma p/(c_sq r_0)
a_p_project = b*b_test*dx(degree=quad_degree)
L_p_project1 = g**2/(r_0_project*N_sq)*(Ma**2*p0/c_sq**2 - Ma*rho0/c_sq)*b_test*dx(degree=quad_degree)
L_p_project2 = (Ma*p0/(c_sq*r_0_project))*b_test*dx(degree=quad_degree)
L_p_project = L_p_project1 + L_p_project2

solve(a_p_project == L_p_project, dHdp0, solver_parameters={'ksp_rtol': 1e-14})


# Assemble the initial energy: the reference value for the drift diagnostic
E0 = assemble((0.5/r_0_project*inner(u0, u0) + g**2/(2*r_0_project*N_sq)*(rho0 - Ma*p0/c_sq)**2 + (Ma/(2*r_0_project*c_sq))*p0**2)*dx(degree=quad_degree))

# Single-argument print call: valid under both Python 2 and Python 3
print("The initial energy is %g " % (E0))

# Facet normal, used by the numerical fluxes on interior facets
# in the bilinear form.
n = FacetNormal(mesh)


# The linear system solved each step has the form
#     a(u, v) = L(v).
# Applying the implicit midpoint rule to the Poisson bracket PB gives
#     a(u, v) = u^{n+1}*dFdu_vec + rho^{n+1}*dFdrho - 0.5*dt*PB(u^{n+1}, rho^{n+1})
#     L(v)    = u^{n}*dFdu_vec   + rho^{n}*dFdrho   + 0.5*dt*PB(u^{n}, rho^{n})
#
# No exterior-facet integrals (ds) appear.  The normal component of the
# variational derivative and of the test function is required to vanish
# on the boundary.


# Trial and test functions on the mixed space
u, rho, p, dHdu, dHdrho, dHdp = TrialFunctions(W)
dFdu_vec, dFdrho, dFdp, dFdu_project, dFdrho_project, dFdp_project = TestFunctions(W)

# Symbolic components of the current state w0, used as coefficients in
# the forms (the variational derivatives at time level n among them).
u0, rho0, p0, dHdu0, dHdrho0, dHdp0 = split(w0)


# Define discrete divergence operator
def div_u(u, p):
    """Return the discrete weak divergence pairing of vector ``u`` with scalar ``p``.

    The form is
        int u . grad(p) dx  +  int_{interior facets} jump(p) * (u_hat . n('-')) dS.

    ``u_hat`` is a theta-weighted average of the two traces of ``u``.  The
    trace that receives weight ``theta`` is chosen from the sign of
    ``f . n('+')``, where ``f`` is the all-ones RT reference field.  This
    keeps the flux direction consistent on every facet.
    """
    plus_side_weighted = ge(dot(f, n)('+'), 0)
    u_hat = conditional(plus_side_weighted,
                        u('+')*theta + u('-')*(1 - theta),
                        u('-')*theta + u('+')*(1 - theta))
    cell_term = (dot(u, grad(p)))*dx(domain=p.ufl_domain())
    facet_term = (jump(p)*dot(u_hat, n('-')))*dS
    return cell_term + facet_term


 


# Right-hand side L(v): state at time level n plus half a timestep of the
# discrete Poisson bracket evaluated at level n (implicit midpoint rule).
L0 = (dot(u0, dFdu_vec) + rho0*dFdrho + p0*dFdp)*dx
L1 = -div_u(dFdu_vec, r_0_project*dHdrho0)
L2 = div_u(dHdu0, r_0_project*dFdrho)
L3 = div_u((c_sq/Ma)*dHdu0, r_0_project*dFdp)
L4 = -div_u((c_sq/Ma)*dFdu_vec, r_0_project*dHdp0)
# Buoyancy coupling through the vertical ([1]) components
L5 = (dr_0*(dHdrho0*dFdu_vec[1] - dFdrho*dHdu0[1]))*dx(domain=p.ufl_domain())
L6 = (g*r_0_project*(dHdp0*dFdu_vec[1] - dFdp*dHdu0[1]))*dx(domain=p.ufl_domain())

L = L0 + 0.5*dt*(L1 + L2 + L3 + L4 + L5 + L6)

# Projection equations defining the level-(n+1) variational derivatives
# from the level-(n+1) trial fields:
#   dHdu   = u / r_0
#   dHdrho = g^2/(r_0 N_sq) (rho - Ma p/c_sq)
#   dHdp   = Ma p/(c_sq r_0) + g^2/(r_0 N_sq) (Ma^2 p/c_sq^2 - Ma rho/c_sq)
# Every term must use the TRIAL functions (u, rho, p) so that it is bilinear.
# Using the previous-step coefficients rho0/p0 here would make these terms
# 1-forms, which cannot be added to the bilinear form ``a``.
a1 = (dot(dHdu, dFdu_project) - dot(u/r_0_project, dFdu_project))*dx(degree=quad_degree)
a2 = (dHdrho*dFdrho_project - (g**2/(r_0_project*N_sq)*(rho - Ma*p/c_sq))*dFdrho_project)*dx(degree=quad_degree)
a3 = (dHdp*dFdp_project - ((Ma*p/(c_sq*r_0_project)) + g**2/(r_0_project*N_sq)*(Ma**2*p/c_sq**2 - Ma*rho/c_sq))*dFdp_project)*dx(degree=quad_degree)

# The bracket part is linear in w0, so its derivative with respect to w0
# yields the same form with trial functions in place of the level-n state.
a = derivative(L0 - 0.5*dt*(L1 + L2 + L3 + L4 + L5 + L6), w0) + a1 + a2 + a3
 

# Storage for visualisation
outfile = File('./Results/compressible_stratified_results.pvd')

# Re-split w0 into Functions (the forms above used the symbolic split()) so
# the components can be renamed and written out.  These Functions share
# their data with w0, so updating w0 updates them as well.
u0, rho0, p0, dHdu0, dHdrho0, dHdp0 = w0.split()

u0.rename("Velocity")
rho0.rename("Density")
p0.rename("Pressure")


# Output initial conditions
outfile.write(u0, rho0, p0, time=t)


# Solution at the new time level
out = Function(W)

problem = LinearVariationalProblem(a, L, out)
solver = LinearVariationalSolver(problem, solver_parameters={'ksp_rtol': 1e-14})


# File for energy output.  The context manager closes it even if a solve
# raises part-way through the loop.
with open('./Results/energy.txt', 'w') as E_file:
    # Solve loop.  Stop half a step early so that round-off accumulated in
    # t (for dt not exactly representable) cannot trigger an extra step.
    while t < end - 0.5*dt:
        # Update time
        t += dt

        solver.solve()
        u, rho, p, dHdu, dHdrho, dHdp = out.split()

        # Assign appropriate name in results file
        u.rename("Velocity")
        rho.rename("Density")
        p.rename("Pressure")

        # Output results
        outfile.write(u, rho, p, time=t)

        # Use the new solution as the previous timestep for the next update.
        # out and w0 live in the same mixed space W; this also updates
        # u0, rho0, p0, dHdu0, dHdrho0 and dHdp0, which view w0's data.
        w0.assign(out)

        # Assemble the energy at the current time
        E = assemble((0.5/r_0_project*inner(u0, u0) + g**2/(2*r_0_project*N_sq)*(rho0 - Ma*p0/c_sq)**2 + (Ma/(2*r_0_project*c_sq))*p0**2)*dx(degree=quad_degree))

        E_file.write('%-10s %-10s\n' % (t, abs((E - E0)/E0)))
        # Print time and energy drift; the drift should be around machine precision.
        print("At time %g, energy drift is %g" % (t, E - E0))
