#================================================
#
#  Advective-diffusive-reactive system example
#  By: Justin Chang
#
#================================================
# Mesh resolution: number of cells along x and along y of the rectangle.
# (These were once read from the command line via
#  xseed = int(sys.argv[1]); yseed = int(sys.argv[2]).)
xseed, yseed = 100, 50
# Sub-step solver switches, diffusion (opt_D) and advection (opt_A):
# 0 -> plain linear solve, anything else -> bound-constrained TAO solve.
opt_D, opt_A = 0, 0
#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
from mpi4py import MPI
from petsc4py import PETSc
from scipy.optimize import fsolve
import math
import numpy as np
import time
# Global MPI communicator and this process's rank (rank 0 does the printing).
comm = MPI.COMM_WORLD
rank = comm.rank
# Do not build nested (MATNEST) matrices for mixed-space operators --
# presumably so A_A.M.handle / A_D_nobc.M.handle are single matrices that the
# TAO callbacks can call .mult on directly.
parameters["matnest"] = False

#========================
#  Create mesh 
#========================
# Quadrilateral mesh of the rectangle [0,2] x [0,1] with xseed x yseed cells.
mesh = RectangleMesh(xseed,yseed,2,1,quadrilateral=True)
# Continuous piecewise-linear vector space for the advection velocity.
V = VectorFunctionSpace(mesh, "CG", 1)
# Continuous piecewise-linear scalar space for each concentration.
Q = FunctionSpace(mesh, "CG", 1)
# Mixed space for the two transported total concentrations (psi_A, psi_B),
# solved together in the advection and diffusion sub-steps.
W = Q*Q
u_A,u_B = TrialFunctions(W)
v_A,v_B = TestFunctions(W)
# Previous-time-level solution and its two components.
# NOTE(review): the weak forms below read u0_A/u0_B while the time loop
# updates u0 in place, which assumes split() returns views sharing u0's data.
u0 = Function(W)
u0_A,u0_B = u0.split()

#========================
#  Advection velocity
#========================
# Steady velocity field: a uniform unit flow in +x plus three sinusoidal
# modes, projected onto V.  Each mode is divergence-free (the x-derivative of
# its x-component cancels the y-derivative of its y-component), and the
# x-component is at least 1 - 0.28*pi > 0, so the velocity never vanishes.
velocity = Function(V)
velocity.project(Expression(("1 + 0.08*pi*cos(4*pi*x[0]*0.5 - 0.5*pi)*cos(pi*x[1]) + 0.1*pi*cos(5*pi*x[0]*0.5 - 0.5*pi)*cos(5*pi*x[1]) + 0.1*pi*cos(10*pi*x[0]*0.5 - 0.5*pi)*cos(10*pi*x[1])","0.16*pi*sin(2*pi*x[0] - 0.5*pi)*sin(pi*x[1]) + 0.05*pi*sin(5*pi*x[0]*0.5 - 0.5*pi)*sin(5*pi*x[1]) + 0.05*pi*sin(5*pi*x[0] - 0.5*pi)*sin(10*pi*x[1])")))

#========================
#  Diffusivity tensor
#========================
# Velocity-dependent dispersion tensor
#   D = alpha_t*|v|*I + (alpha_L - alpha_t)*(v (x) v)/|v|,
# with |v| the Euclidean norm of the velocity.  Along the flow direction this
# gives alpha_L*|v|, across it alpha_t*|v|.  normv must therefore be the norm,
# not the squared norm inner(v, v).  The velocity's x-component is bounded
# below by 1 - 0.28*pi > 0, so |v| > 0 and the division is safe.
normv = sqrt(inner(velocity,velocity))
Id = Identity(mesh.ufl_cell().geometric_dimension())
alpha_t = Constant(0.0001)   # dispersivity across the flow
alpha_L = Constant(1)        # dispersivity along the flow
D = alpha_t*normv*Id+(alpha_L-alpha_t)*outer(velocity,velocity)/normv

#========================
#  Volumetric source
#========================
# Source terms for psi_A and psi_B (used only in the advection right-hand side).
# NOTE(review): both branches of each conditional are 0.0, so these sources are
# identically zero; the 0.05 x 0.05 patch conditions look like placeholders for
# a localized injection -- confirm the intended source strength.
f_A = Function(W.sub(0))
f_B = Function(W.sub(1))
f_A.project(Expression(("x[1] <= 0.75 && x[1] >= 0.70 && x[0] >= 0.2 && x[0] <= 0.25 ? 0.0 : 0.0")))
f_B.project(Expression(("x[1] <= 0.30 && x[1] >= 0.25 && x[0] >= 0.2 && x[0] <= 0.25 ? 0.0 : 0.0")))

#========================
#  Time-stepping
#========================
# Step size, final simulation time, and the running clock (starts at zero).
dt, T, t = 0.05, 2, 0

#========================
#  Solver parameters
#========================
# Options shared by the advection (solver_A) and diffusion (solver_D) linear
# solves: conjugate gradients preconditioned by hypre BoomerAMG.
OS_parameters = dict(
  ksp_type='cg',
  pc_type='hypre',
  pc_hypre_type='boomeramg',
)
# GMRES with block-Jacobi preconditioning; not referenced elsewhere in this
# script.
R_parameters = dict(
  ksp_type='gmres',
  pc_type='bjacobi',
)

#========================
#  Output files
#========================
# ParaView (.pvd) series written each time step in the loop below.
# outFile_velocity is opened but never written to in this script.
outFile_velocity = File("2D_plume_ADR_ex1/vel.pvd")
# Species concentrations produced by the reaction kernel.
outFile_A = File("2D_plume_ADR_ex1/C_A.pvd")
outFile_B = File("2D_plume_ADR_ex1/C_B.pvd")
outFile_C = File("2D_plume_ADR_ex1/C_C.pvd")
# Transported totals after the advection and diffusion sub-steps.
outFile_psiA = File("2D_plume_ADR_ex1/psi_A.pvd")
outFile_psiB = File("2D_plume_ADR_ex1/psi_B.pvd")

#========================
#  Weak form
#========================
# Advection weak form
# Least-squares form of one backward-Euler advection step
#   u + dt*velocity.grad(u) = u0 + dt*f,
# tested against (v + dt*velocity.grad(v)); this yields a symmetric operator,
# which the CG solver and the quadratic TAO objective both rely on.
a_A_psiA = (v_A + dt*dot(velocity,grad(v_A)))*(u_A + dt*dot(velocity,grad(u_A)))*dx
a_A_psiB = (v_B + dt*dot(velocity,grad(v_B)))*(u_B + dt*dot(velocity,grad(u_B)))*dx
L_A_psiA = (v_A + dt*dot(velocity,grad(v_A)))*(u0_A + dt*f_A)*dx
L_A_psiB = (v_B + dt*dot(velocity,grad(v_B)))*(u0_B + dt*f_B)*dx
# The two components are uncoupled; sum them into one mixed-space system.
a_A = a_A_psiA + a_A_psiB
L_A = L_A_psiA + L_A_psiB
# Diffusion weak form
# Backward-Euler step of u - u0 = dt*div(D*grad(u)) (no source term here).
a_D_psiA = v_A*u_A*dx + dt*dot(grad(v_A),D*grad(u_A))*dx
a_D_psiB = v_B*u_B*dx + dt*dot(grad(v_B),D*grad(u_B))*dx
L_D_psiA = v_A*u0_A*dx
L_D_psiB = v_B*u0_B*dx
a_D = a_D_psiA + a_D_psiB
L_D = L_D_psiA + L_D_psiB

#========================
#  Reactions
#========================
# Mixed space for the three species concentrations c_A, c_B, c_C.
# dc_*/q_* are declared but not used by the pointwise kernel below.
W_R = Q*Q*Q
dc_A,dc_B,dc_C = TrialFunctions(W_R)
q_A,q_B,q_C = TestFunctions(W_R)
c = Function(W_R)
c_A,c_B,c_C = c.split()

# Rate constant in c_C = k1*c_A*c_B (see the kernel); must be > 0.
k1 = Constant(1.0)
species_kernel = op2.Kernel("""
  // The par_loop in the time loop is a direct loop over the CG1 node set with
  // scalar Dats, so each call of solve_kernel receives a single node.
  // NOTE(review): Nvals is not defined anywhere in this script; default it to
  // 1 unless the code generator already provides a definition.
  #ifndef Nvals
  #define Nvals 1
  #endif

  // Speciate one node for the equilibrium A + B <-> C:
  //   c_C = k1*c_A*c_B,  psi_A = c_A + c_C,  psi_B = c_B + c_C.
  // Eliminating c_A and c_B gives
  //   k1*c_C^2 - (k1*psi_A + k1*psi_B + 1)*c_C + k1*psi_A*psi_B = 0.
  // The physical solution is the SMALLER root (minus sign on the square
  // root): it satisfies c_A + c_C == psi_A and c_B + c_C == psi_B
  // algebraically and keeps c_A, c_B, c_C >= 0 when psi_A, psi_B >= 0, k1 > 0.
  void solve_system(const double psi_A, const double psi_B, const double k1,
    double *c_A, double *c_B, double *c_C) {
    const double s = sqrt(pow(psi_A*k1 - psi_B*k1 - 1,2) + 4*psi_A*k1);
    *c_A = (s + psi_A*k1 - psi_B*k1 - 1)/(2*k1);
    *c_B = (s - psi_A*k1 + psi_B*k1 - 1)/(2*k1);
    *c_C = (psi_A*k1 + psi_B*k1 + 1 - s)/(2*k1);
  }

  // Speciate every node handed to this call.
  void solve_kernel(const double *psi_A, const double *psi_B, const double *k1,
    double *c_A, double *c_B, double *c_C) {
    for (int i = 0; i < Nvals; i++) {
      solve_system(psi_A[i],psi_B[i],k1[0],&(c_A[i]),&(c_B[i]),&(c_C[i]));
    }
  }

""","solve_kernel")

#========================
#  Dirichlet BCs
#========================
# On boundary id 1, psi_A = 1 on the lower half (y <= 0.5) and psi_B = 1 on the
# upper half (y > 0.5), each 0 elsewhere on that boundary.
# NOTE(review): for Firedrake's RectangleMesh id 1 should be the edge x == 0
# and ids 2,3,4 the edges x == Lx, y == 0, y == Ly -- confirm.
# `(1)` is just the integer 1, not a tuple.
bc1 = DirichletBC(W.sub(0), Expression(("x[1] <= 0.5 ? 1.0 : 0.0")), (1))
bc2 = DirichletBC(W.sub(1), Expression(("x[1] > 0.5 ? 1.0 : 0.0")), (1))
# Homogeneous values for both components on the remaining boundary ids.
bc3 = DirichletBC(W.sub(0), 0.0, (2,3,4))
bc4 = DirichletBC(W.sub(1), 0.0, (2,3,4))
# The advection step only imposes the id-1 values; diffusion imposes all four.
bc_A = [bc1,bc2]
bc_D = [bc1,bc2,bc3,bc4]

#========================
#  Initial condition
#========================
# Start from zero everywhere: species concentrations first, then the
# transported totals (the two assignments are independent).
c.assign(0)
u0.assign(0)

#========================
#  Optimization 
#========================
# Box constraints 0 <= psi <= 1 on every DOF of the mixed solution, shared by
# the advection (tao_A) and diffusion (tao_D) bound-constrained solvers.
lb = Function(W)
ub = Function(W)
lb.assign(0)
ub.assign(1)
tao_A = PETSc.TAO().create(MPI.COMM_WORLD)
tao_D = PETSc.TAO().create(MPI.COMM_WORLD)
# Both bound vectors are only read here, so request read-only access for
# each; the read-write `.vec` accessor is meant for code that modifies data.
with lb.dat.vec_ro as lb_vec, ub.dat.vec_ro as ub_vec:
  tao_A.setVariableBounds(lb_vec,ub_vec)
  tao_D.setVariableBounds(lb_vec,ub_vec)

# Scratch Function: Obj_D writes the product A_D_nobc*x into its vector.
g = Function(W)
def ObjGrad_A(tao, petsc_x, petsc_g):
  """TAO objective+gradient callback for the advection sub-step.

  Evaluates f(x) = 0.5*x^T A x - x^T b with A = A_A and b = b_A (module-level;
  b_A is re-assembled every time step) and stores A x - b in petsc_g, which is
  the gradient because the least-squares advection operator is symmetric.
  Returns f(x).
  """
  # b_A is only read, so ask for read-only access to its PETSc vector.
  with b_A.dat.vec_ro as b_vec:
    A_A.M.handle.mult(petsc_x,petsc_g)   # petsc_g <- A x
    xtHx = petsc_x.dot(petsc_g)
    xtf = petsc_x.dot(b_vec)
    petsc_g.axpy(-1.0,b_vec)             # petsc_g <- A x - b
  return 0.5*xtHx - xtf

# Objective function
def Obj_D(tao, petsc_x):
  """TAO objective callback for the diffusion sub-step.

  Returns 0.5*x^T A x - x^T b with A = A_D_nobc (diffusion operator assembled
  without BCs) and b = b_D.  The module-level Function g holds A x as scratch.

  NOTE(review): b_D is assembled with bc_D while A_D_nobc has no BCs, and the
  time loop also assembles an unused b_D_nobc -- confirm which pairing of
  operator and right-hand side is intended.
  """
  # b_D is only read; g's vector is overwritten, so it keeps read-write access.
  with b_D.dat.vec_ro as b_vec, g.dat.vec as g_vec:
    A_D_nobc.M.handle.mult(petsc_x,g_vec)
    xtHx = petsc_x.dot(g_vec)
    xtf = petsc_x.dot(b_vec)
  return 0.5*xtHx - xtf

# Gradient routine
def Grad_D(tao, petsc_x, petsc_g):
  """TAO gradient callback for the diffusion sub-step: petsc_g <- A x - b.

  Uses the same A = A_D_nobc and b = b_D as Obj_D.
  """
  with b_D.dat.vec_ro as b_vec:
    A_D_nobc.M.handle.mult(petsc_x,petsc_g)
    petsc_g.axpy(-1.0,b_vec)

# Diffusion solver: separate objective and gradient callbacks.
tao_D.setObjective(Obj_D)
tao_D.setGradient(Grad_D)
# Advection solver: one combined objective+gradient callback.
tao_A.setObjectiveGradient(ObjGrad_A)
# Both use BLMVM (bound-constrained limited-memory variable metric).
tao_A.setType(PETSc.TAO.Type.BLMVM)
tao_D.setType(PETSc.TAO.Type.BLMVM)
# NOTE(review): the meaning of these positional tolerances depends on the
# petsc4py version's setTolerances signature -- confirm against it.
tao_A.setTolerances(1e-8,1e-7)
tao_D.setTolerances(1e-8,1e-7)

#========================
#  AD linear operators
#========================
# The wall-clock timer starts here, so the reported time covers operator
# assembly, solver setup and the whole time loop.
initialTime = time.time()
# Diffusion operator with all Dirichlet BCs.  Only solver_D (the opt_D == 0
# path) uses it; the TAO diffusion path works with A_D_nobc instead, so the
# same assembly serves both settings of opt_D.
A_D = assemble(a_D,bcs=bc_D)
# Advection operator: the TAO path (opt_A == 1) uses the BC-free operator in
# ObjGrad_A, while the linear-solver path applies the boundary-id-1 BCs.
if opt_A == 1:
  A_A = assemble(a_A)
else:
  A_A = assemble(a_A,bcs=bc_A)
solver_A = LinearSolver(A_A,options_prefix="A_",solver_parameters=OS_parameters)
solver_D = LinearSolver(A_D,options_prefix="D_",solver_parameters=OS_parameters)

#=======================
#  AD solver function
#=======================
def ADsolve(Operator,u0_O):
  """Advance the mixed solution u0_O in place by one split sub-step.

  Operator 'A' performs the advection solve (right-hand side b_A) and 'D' the
  diffusion solve (right-hand side b_D).  For each, a switch value of 0
  (opt_A / opt_D) selects the linear solver; any other value runs the
  bound-constrained TAO solver directly on u0_O's vector.  Any other Operator
  leaves u0_O untouched.  Returns u0_O itself.
  """
  if Operator == 'A':
    use_linear_solver = (opt_A == 0)
  elif Operator == 'D':
    use_linear_solver = (opt_D == 0)
  else:
    # Unknown operator: nothing to do.
    return u0_O

  if use_linear_solver:
    if Operator == 'A':
      solver_A.solve(u0_O,b_A)
    else:
      solver_D.solve(u0_O,b_D)
  else:
    optimizer = tao_A if Operator == 'A' else tao_D
    with u0_O.dat.vec as u0_vec:
      optimizer.solve(u0_vec)
  return u0_O

#=======================
#  Time-stepping
#=======================
# Diffusion operator without BCs, used by the TAO diffusion callbacks.
A_D_nobc = assemble(a_D)
# Count an integer number of steps so the run ends exactly at t == T.
# Testing `t <= T` on an accumulated float would take one step past T in
# exact arithmetic, and its step count would depend on round-off.
nsteps = int(round(T/dt))
for step in range(1, nsteps + 1):
  t = step*dt
  if rank == 0:
    print('\tTime = %f' % t)

  # Advection-diffusion (operator splitting: advection, then diffusion)
  b_A = assemble(L_A,bcs=bc_A)
  u0 = ADsolve('A',u0)
  b_D = assemble(L_D,bcs=bc_D)
  # NOTE(review): b_D_nobc is never read (Obj_D/Grad_D use b_D); confirm which
  # right-hand side the TAO diffusion solve is meant to use.
  b_D_nobc = assemble(L_D)
  u0 = ADsolve('D',u0)
  solA,solB = u0.split()
  # NOTE(review): if split() returns views of u0's data (as the weak forms
  # assume), these assignments copy each component onto itself.
  u0_A.assign(solA)
  u0_B.assign(solB)

  # Output total concentrations
  outFile_psiA << u0_A
  outFile_psiB << u0_B

  # Solve reactions: speciate psi_A, psi_B into c_A, c_B, c_C at every node
  op2.par_loop(species_kernel,Q.dof_dset.set,solA.dat(op2.READ),solB.dat(op2.READ),k1.dat(op2.READ),c_A.dat(op2.WRITE),c_B.dat(op2.WRITE),c_C.dat(op2.WRITE))

  # Output geochemical concentrations
  outFile_A << c_A
  outFile_B << c_B
  outFile_C << c_C

#=======================
#  End simulation
#=======================
# Wall-clock seconds elapsed since initialTime; only rank 0 reports it.
getTime = time.time() - initialTime
if rank == 0:
  print('\tWall-clock time: %.3e seconds' % (getTime,))
