#================================================
#
#  Advective-diffusive-reactive system example
#  By: Justin Chang
#
#================================================
# Run configuration.
# xseed, yseed: number of mesh cells in x and y (passed to RectangleMesh).
#   They used to be read from sys.argv[1] / sys.argv[2].
# OS: transport scheme selector. OS == 0 uses a single SUPG-stabilised
#   advection-diffusion solve; any other value uses operator splitting
#   (advection sub-step followed by diffusion sub-step).
xseed, yseed = 50, 25
OS = 1

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
from mpi4py import MPI
import numpy as np
import time
# MPI communicator and this process's rank; only rank 0 prints to the console
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

#========================
#  Create mesh 
#========================
# Quadrilateral mesh of the 2 x 1 rectangle, xseed x yseed cells
# (RectangleMesh(nx, ny, Lx, Ly, ...))
mesh = RectangleMesh(xseed,yseed,2,1,quadrilateral=True)
# Continuous degree-1 spaces: V for the velocity vector, Q for scalar fields
V = VectorFunctionSpace(mesh, "CG", 1)
Q = FunctionSpace(mesh, "CG", 1)
# Mixed space for the two transported totals (written out as psi_A, psi_B)
W = Q*Q
u_A,u_B = TrialFunctions(W)
v_A,v_B = TestFunctions(W)
# u0 holds the current transported state; u0_A / u0_B are per-component
# handles obtained from u0.split() and are used as coefficients in the forms
u0 = Function(W)
u0_A,u0_B = u0.split()

#========================
#  Advection velocity
#========================
# Prescribed, steady advection field: a uniform 0.1 drift in x plus three
# cellular perturbation modes. Within each mode the x- and y-components are
# paired so that the mode is divergence-free.
_vel_x = (".1 + 0.008*pi*cos(4*pi*x[0]*0.5 - 0.5*pi)*cos(pi*x[1])"
          " + 0.01*pi*cos(5*pi*x[0]*0.5 - 0.5*pi)*cos(5*pi*x[1])"
          " + 0.01*pi*cos(10*pi*x[0]*0.5 - 0.5*pi)*cos(10*pi*x[1])")
_vel_y = ("0.016*pi*sin(2*pi*x[0] - 0.5*pi)*sin(pi*x[1])"
          " + 0.005*pi*sin(5*pi*x[0]*0.5 - 0.5*pi)*sin(5*pi*x[1])"
          " + 0.005*pi*sin(5*pi*x[0] - 0.5*pi)*sin(10*pi*x[1])")
velocity = Function(V)
velocity.project(Expression((_vel_x, _vel_y)))

#========================
#  Diffusivity tensor
#========================
# Velocity magnitude |v|. The dispersion tensor below needs the norm itself;
# using inner(v, v) = |v|^2 here makes the isotropic term scale like |v|^2 and
# the directional term dimensionless, which is inconsistent.
normv = sqrt(inner(velocity,velocity))
Id = Identity(mesh.ufl_cell().geometric_dimension())
# Transverse (alpha_t) and longitudinal (alpha_L) dispersivities: along v the
# effective diffusivity is alpha_L*|v|, across v it is alpha_t*|v|.
alpha_t = Constant(0.0001)
alpha_L = Constant(1)
# D = alpha_t*|v|*I + (alpha_L - alpha_t) * (v (x) v) / |v|
D = alpha_t*normv*Id+(alpha_L-alpha_t)*outer(velocity,velocity)/normv

#========================
#  Volumetric source
#========================
# Volumetric source terms, one per transported species, each on its own
# component subspace of W.
# NOTE(review): both branches of each conditional evaluate to 0.0, so these
# sources are identically zero as written -- presumably disabled on purpose
# (the box conditions mark where a source would act); confirm.
f_A, f_B = Function(W.sub(0)), Function(W.sub(1))
for _src_fn, _src_expr in (
    (f_A, "x[1] <= 0.75 && x[1] >= 0.70 && x[0] >= 0.2 && x[0] <= 0.25 ? 0.0 : 0.0"),
    (f_B, "x[1] <= 0.30 && x[1] >= 0.25 && x[0] >= 0.2 && x[0] <= 0.25 ? 0.0 : 0.0"),
):
  _src_fn.project(Expression(_src_expr))

#========================
#  Time-stepping
#========================
# Fixed step size and final time. t is the time level being computed, so
# the first step produces the solution at t = dt.
dt, T = 0.02, 1
t = dt

#========================
#  Solver parameters
#========================
# PETSc options for each solve.
# The operator-split advection (least-squares) and diffusion systems are
# symmetric, which is consistent with CG. The SUPG system and the reaction
# Jacobian are non-symmetric, which is consistent with GMRES.
SUPG_parameters = dict(ksp_type='gmres', pc_type='bjacobi')
OS_parameters = dict(ksp_type='cg', pc_type='bjacobi')
R_parameters = dict(ksp_type='gmres', pc_type='bjacobi')

#========================
#  Output files
#========================
# ParaView (.pvd) output series, all under 2D_plume_ADR_ex1/.
# The velocity is never updated after its projection, so it is written once
# here; the remaining files are appended to inside the time loop.
outFile_velocity = File("2D_plume_ADR_ex1/vel.pvd")
outFile_velocity << velocity
# Speciated concentrations c_A, c_B, c_C from the reaction solve
outFile_A = File("2D_plume_ADR_ex1/C_A.pvd")
outFile_B = File("2D_plume_ADR_ex1/C_B.pvd")
outFile_C = File("2D_plume_ADR_ex1/C_C.pvd")
# Transported totals (components of u0) after the advection-diffusion step
outFile_psiA = File("2D_plume_ADR_ex1/psi_A.pvd")
outFile_psiB = File("2D_plume_ADR_ex1/psi_B.pvd")

#========================
#  AD - SUPG
#========================
if OS == 0:
  # SUPG stabilization: each test function w is perturbed along the
  # streamlines by tau*(v . grad w), with tau = h/(2|v|)
  h = CellSize(mesh)
  vnorm = sqrt(dot(velocity,velocity))
  Pe_A = h/(2*vnorm)*dot(velocity,grad(v_A))
  Pe_B = h/(2*vnorm)*dot(velocity,grad(v_B))
  # Backward-Euler strong residual multiplied through by dt,
  #   u + dt*(v.grad(u) - div(D grad u)) = u0 + dt*f,
  # tested against the streamline perturbation
  a_A_r = Pe_A*(u_A + dt*(dot(velocity,grad(u_A)) - div(D*grad(u_A))))*dx
  a_B_r = Pe_B*(u_B + dt*(dot(velocity,grad(u_B)) - div(D*grad(u_B))))*dx
  L_A_r = Pe_A*(u0_A + dt*f_A)*dx
  L_B_r = Pe_B*(u0_B + dt*f_B)*dx
  # SUPG weak form: stabilization terms plus, for each species, the Galerkin
  # mass term and dt*(advection + diffusion). Species B needs its mass term
  # v_B*u_B exactly like species A; the right-hand side below carries
  # v_B*u0_B, so omitting it would drop u_B from the B time derivative.
  a_SUPG = (a_A_r + a_B_r
            + v_A*u_A*dx + v_B*u_B*dx
            + dt*(v_A*dot(velocity,grad(u_A))*dx + dot(grad(v_A),D*grad(u_A))*dx)
            + dt*(v_B*dot(velocity,grad(u_B))*dx + dot(grad(v_B),D*grad(u_B))*dx))
  L_SUPG = L_A_r + L_B_r + v_A*u0_A*dx + dt*v_A*f_A*dx + v_B*u0_B*dx + dt*v_B*f_B*dx

#========================
#  AD - Operator split
#========================
else:
  # Advection sub-step: least-squares form of the backward-Euler step,
  #   (w + dt v.grad w, u + dt v.grad u) = (w + dt v.grad w, u0 + dt f),
  # which is symmetric in trial/test (consistent with CG in OS_parameters)
  a_A_psiA = (v_A + dt*dot(velocity,grad(v_A)))*(u_A + dt*dot(velocity,grad(u_A)))*dx
  a_A_psiB = (v_B + dt*dot(velocity,grad(v_B)))*(u_B + dt*dot(velocity,grad(u_B)))*dx
  L_A_psiA = (v_A + dt*dot(velocity,grad(v_A)))*(u0_A + dt*f_A)*dx
  L_A_psiB = (v_B + dt*dot(velocity,grad(v_B)))*(u0_B + dt*f_B)*dx
  a_A = a_A_psiA + a_A_psiB
  L_A = L_A_psiA + L_A_psiB
  # Diffusion sub-step: backward-Euler mass + dt*stiffness with tensor D,
  # applied to the state left in u0 by the advection sub-step
  a_D_psiA = v_A*u_A*dx + dt*dot(grad(v_A),D*grad(u_A))*dx
  a_D_psiB = v_B*u_B*dx + dt*dot(grad(v_B),D*grad(u_B))*dx
  L_D_psiA = v_A*u0_A*dx
  L_D_psiB = v_B*u0_B*dx
  a_D = a_D_psiA + a_D_psiB
  L_D = L_D_psiA + L_D_psiB

#========================
#  Reactions - check this
#========================
# Algebraic (equilibrium-type) speciation: given the transported totals
# u0_A and u0_B, find (c_A, c_B, c_C) such that
#   c_A + c_C = u0_A,   c_B + c_C = u0_B,   c_C = c_A*c_B.
W_R = Q*Q*Q
dc_A,dc_B,dc_C = TrialFunctions(W_R)
q_A,q_B,q_C = TestFunctions(W_R)
c = Function(W_R)
# Function handles into c, for assign() and file output
c_A,c_B,c_C = c.split()
# Symbolic UFL components of c. Forms for a nonlinear solve must be written
# in terms of split(c), not the c.split() Function handles, so that the
# residual and Jacobian depend on the unknown c itself.
cu_A,cu_B,cu_C = split(c)

# Residual and its exact Jacobian (Gateaux derivative of F w.r.t. c)
F = (q_A*(cu_A + cu_C - u0_A) + q_B*(cu_B + cu_C - u0_B) + q_C*(cu_A*cu_B - cu_C))*dx
J = (q_A*(dc_A + dc_C) + q_B*(dc_B + dc_C) + q_C*(cu_B*dc_A + cu_A*dc_B - dc_C))*dx
problem_R = NonlinearVariationalProblem(F,c,J=J)
solver_R = NonlinearVariationalSolver(problem_R,solver_parameters=R_parameters,options_prefix="R_")

#========================
#  Dirichlet BCs
#========================
# Dirichlet data on the transport space W.
# Boundary id 1 (presumably the x = 0 edge of RectangleMesh, i.e. the inflow
# side for this velocity -- confirm) carries the inflow profile: species A
# enters for y <= 0.5, species B for y > 0.5. Ids 2, 3, 4 are held at zero.
bc1, bc2 = (DirichletBC(W.sub(0), Expression("x[1] <= 0.5 ? 1.0 : 0.0"), 1),
            DirichletBC(W.sub(1), Expression("x[1] > 0.5 ? 1.0 : 0.0"), 1))
bc3, bc4 = (DirichletBC(W.sub(comp), 0.0, (2, 3, 4)) for comp in (0, 1))
if OS != 0:
  # Advection sub-step is constrained on the inflow only; diffusion everywhere
  bc_A = [bc1, bc2]
  bc_D = [bc1, bc2, bc3, bc4]
else:
  bc_SUPG = [bc1, bc2, bc3, bc4]

#========================
#  Initial condition
#========================
# Zero initial state: the transported totals (u0 and its component handles)
# and the reaction unknowns (c and its component handles), in that order.
for _field in (u0, u0_A, u0_B, c, c_A, c_B, c_C):
  _field.assign(0)

#========================
#  AD linear operators
#========================
# Start the wall-clock timer; the reported interval covers operator assembly
# and the whole time loop.
initialTime = time.time()
# The bilinear forms do not depend on u0, so each matrix is assembled once
# and reused every step; only right-hand sides are reassembled in the loop.
# options_prefix lets each solver be tuned independently via PETSc options.
if OS == 0:
  A_SUPG = assemble(a_SUPG,bcs=bc_SUPG)
  solver_SUPG = LinearSolver(A_SUPG,options_prefix="SUPG_",solver_parameters=SUPG_parameters)
else:
  A_D = assemble(a_D,bcs=bc_D)
  A_A = assemble(a_A,bcs=bc_A)
  solver_A = LinearSolver(A_A,options_prefix="A_",solver_parameters=OS_parameters)
  solver_D = LinearSolver(A_D,options_prefix="D_",solver_parameters=OS_parameters)

#=======================
#  AD solver functions
#=======================
def OSsolve(solver_O,L_O,bc_O,u0):
  """Run one linear (sub-)step with a pre-assembled solver.

  The right-hand-side form L_O is assembled with the boundary conditions
  bc_O applied, then solver_O solves into u0 in place. The same u0 object
  is returned so callers can rebind their name to it.
  """
  solver_O.solve(u0, assemble(L_O, bcs=bc_O))
  return u0

#=======================
#  Time-stepping
#=======================
# Time loop. At each level t = dt, 2*dt, ... do one transport step (SUPG, or
# advection followed by diffusion) and then the reaction solve.
# The half-step slack makes the level that lands on T run deterministically;
# with a bare `t < T`, round-off accumulated by `t += dt` decided whether the
# final step happened.
while t < T + 0.5*dt:
  if rank == 0:
    # Single-argument print in parentheses behaves the same on Python 2 and 3
    print('\tTime = %f' % t)
  # Advection-diffusion. After each solve, refresh the component handles used
  # as coefficients in the forms (likely a no-op if split() returns views of
  # u0, kept for safety).
  if OS == 0:
    u0 = OSsolve(solver_SUPG,L_SUPG,bc_SUPG,u0)
    solA,solB = u0.split()
    u0_A.assign(solA)
    u0_B.assign(solB)
  else:
    # Advection sub-step ...
    u0 = OSsolve(solver_A,L_A,bc_A,u0)
    solA,solB = u0.split()
    u0_A.assign(solA)
    u0_B.assign(solB)
    # ... then diffusion sub-step starting from the advected state
    u0 = OSsolve(solver_D,L_D,bc_D,u0)
    solA,solB = u0.split()
    u0_A.assign(solA)
    u0_B.assign(solB)

  # Output total concentrations
  outFile_psiA << u0_A
  outFile_psiB << u0_B

  # Solve reactions (reads the updated totals u0_A, u0_B)
  solver_R.solve()

  # Output geochemical concentrations
  conc_A,conc_B,conc_C = c.split()
  outFile_A << conc_A
  outFile_B << conc_B
  outFile_C << conc_C
  t += dt

#=======================
#  End simulation
#=======================
# Elapsed wall-clock time since operator assembly started, reported by rank 0.
# Parenthesised single-argument print works on both Python 2 and 3.
getTime = time.time()-initialTime
if rank == 0:
  print('\tWall-clock time: %.3e seconds' % getTime)
