#================================================
#
#  Advective-diffusive-reactive system example
#  By: Justin Chang
#
#================================================
# Mesh resolution: number of cells in the x and y directions.
# (An earlier variant read these from sys.argv[1] and sys.argv[2].)
xseed, yseed = 100, 50
# Solution-strategy switches:
#   OS         -- 0: monolithic SUPG advection-diffusion solve,
#                 otherwise: operator splitting (advection, then diffusion)
#   optimize_D -- 0: plain linear solve for the diffusion step,
#                 otherwise: bound-constrained TAO solve for the diffusion step
OS, optimize_D = 1, 1

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
from mpi4py import MPI
from petsc4py import PETSc
from scipy.optimize import fsolve
import math
import numpy as np
import time
# MPI communicator and this process's rank (rank 0 does the console printing).
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

#========================
#  Create mesh
#========================
# Quadrilateral mesh with xseed x yseed cells; presumably the domain is
# [0,2] x [0,1] (RectangleMesh(nx, ny, Lx, Ly)) -- confirm against Firedrake docs.
mesh = RectangleMesh(xseed,yseed,2,1,quadrilateral=True)
# CG1 vector space for the advection velocity.
V = VectorFunctionSpace(mesh, "CG", 1)
# CG1 scalar space; the mixed space W = Q x Q holds the two transported
# quantities (written out as psi_A and psi_B in the time loop).
Q = FunctionSpace(mesh, "CG", 1)
W = Q*Q
u_A,u_B = TrialFunctions(W)
v_A,v_B = TestFunctions(W)
# Solution at the previous time level, and handles to its two components.
u0 = Function(W)
u0_A,u0_B = u0.split()

#========================
#  Advection velocity
#========================
# Prescribed steady velocity field projected into the CG1 vector space V.
# The x-component is 1 plus three cosine perturbations; the y-component is a
# sum of three sine modes. The field is not modified anywhere after this.
velocity = Function(V)
velocity.project(Expression(("1 + 0.08*pi*cos(4*pi*x[0]*0.5 - 0.5*pi)*cos(pi*x[1]) + 0.1*pi*cos(5*pi*x[0]*0.5 - 0.5*pi)*cos(5*pi*x[1]) + 0.1*pi*cos(10*pi*x[0]*0.5 - 0.5*pi)*cos(10*pi*x[1])","0.16*pi*sin(2*pi*x[0] - 0.5*pi)*sin(pi*x[1]) + 0.05*pi*sin(5*pi*x[0]*0.5 - 0.5*pi)*sin(5*pi*x[1]) + 0.05*pi*sin(5*pi*x[0] - 0.5*pi)*sin(10*pi*x[1])")))

#========================
#  Diffusivity tensor
#========================
# Velocity-dependent (mechanical dispersion) diffusivity tensor
#
#   D = alpha_t*|v|*I + (alpha_L - alpha_t)*(v (x) v)/|v|
#
# where |v| is the Euclidean norm of the velocity. Previously normv was
# inner(v,v) = |v|^2, which made the transverse term scale like |v|^2 and
# the longitudinal term like |v|^0, instead of both scaling like |v|.
normv = sqrt(inner(velocity,velocity))
Id = Identity(mesh.ufl_cell().geometric_dimension())
alpha_t = Constant(0.0001)  # transverse dispersivity
alpha_L = Constant(1)       # longitudinal dispersivity
D = alpha_t*normv*Id+(alpha_L-alpha_t)*outer(velocity,velocity)/normv

#========================
#  Volumetric source
#========================
# NOTE: both branches of each conditional evaluate to 0.0, so f_A and f_B are
# identically zero. The x/y ranges mark small rectangular patches where a
# nonzero source could presumably be switched on by changing the first value.
f_A = Function(W.sub(0))
f_B = Function(W.sub(1))
f_A.project(Expression(("x[1] <= 0.75 && x[1] >= 0.70 && x[0] >= 0.2 && x[0] <= 0.25 ? 0.0 : 0.0")))
f_B.project(Expression(("x[1] <= 0.30 && x[1] >= 0.25 && x[0] >= 0.2 && x[0] <= 0.25 ? 0.0 : 0.0")))

#========================
#  Time-stepping
#========================
dt = 0.05  # time-step size
T = 2      # final time
t = 0      # current simulation time

#========================
#  Solver parameters
#========================
# PETSc options for each solve. The operator-split advection and diffusion
# forms are symmetric, hence CG; the SUPG and reaction systems use GMRES.
SUPG_parameters = dict(ksp_type='gmres', pc_type='bjacobi')
OS_parameters = dict(ksp_type='cg', pc_type='bjacobi')
R_parameters = dict(ksp_type='gmres', pc_type='bjacobi')

#========================
#  Output files
#========================
# ParaView (.pvd) output under 2D_plume_ADR_ex1/. The velocity is written
# once here. psi_A/psi_B are written every time step. C_A/C_B/C_C are only
# used by the currently commented-out reaction output in the time loop.
outFile_velocity = File("2D_plume_ADR_ex1/vel.pvd")
outFile_velocity << velocity
outFile_A = File("2D_plume_ADR_ex1/C_A.pvd")
outFile_B = File("2D_plume_ADR_ex1/C_B.pvd")
outFile_C = File("2D_plume_ADR_ex1/C_C.pvd")
outFile_psiA = File("2D_plume_ADR_ex1/psi_A.pvd")
outFile_psiB = File("2D_plume_ADR_ex1/psi_B.pvd")

#========================
#  AD - SUPG
#========================
if OS == 0:
  # SUPG stabilization. The strong, dt-scaled backward-Euler residual of each
  # species is tested against tau*(v . grad w), with tau = h/(2|v|).
  h = CellSize(mesh)
  vnorm = sqrt(dot(velocity,velocity))
  Pe_A = h/(2*vnorm)*dot(velocity,grad(v_A))
  Pe_B = h/(2*vnorm)*dot(velocity,grad(v_B))
  a_A_r = Pe_A*(u_A + dt*(dot(velocity,grad(u_A)) - div(D*grad(u_A))))*dx
  a_B_r = Pe_B*(u_B + dt*(dot(velocity,grad(u_B)) - div(D*grad(u_B))))*dx
  L_A_r = Pe_A*(u0_A + dt*f_A)*dx
  L_B_r = Pe_B*(u0_B + dt*f_B)*dx
  # SUPG weak form: Galerkin backward-Euler terms for both species plus the
  # stabilization terms above. Each species needs its own mass term w*u*dx.
  # The B mass term v_B*u_B*dx used to be missing, even though the
  # right-hand side below contains its counterpart v_B*u0_B*dx.
  a_SUPG = (a_A_r + a_B_r
            + v_A*u_A*dx + v_B*u_B*dx
            + dt*(v_A*dot(velocity,grad(u_A))*dx + dot(grad(v_A),D*grad(u_A))*dx)
            + dt*(v_B*dot(velocity,grad(u_B))*dx + dot(grad(v_B),D*grad(u_B))*dx))
  L_SUPG = L_A_r + L_B_r + v_A*u0_A*dx + dt*v_A*f_A*dx + v_B*u0_B*dx + dt*v_B*f_B*dx

#========================
#  AD - Operator split
#========================
else:
  # Advection step: least-squares form (w + dt v.grad w)*(u + dt v.grad u),
  # which is symmetric in trial and test functions.
  a_A_psiA = (v_A + dt*dot(velocity,grad(v_A)))*(u_A + dt*dot(velocity,grad(u_A)))*dx
  a_A_psiB = (v_B + dt*dot(velocity,grad(v_B)))*(u_B + dt*dot(velocity,grad(u_B)))*dx
  L_A_psiA = (v_A + dt*dot(velocity,grad(v_A)))*(u0_A + dt*f_A)*dx
  L_A_psiB = (v_B + dt*dot(velocity,grad(v_B)))*(u0_B + dt*f_B)*dx
  a_A = a_A_psiA + a_A_psiB
  L_A = L_A_psiA + L_A_psiB
  # Diffusion step: backward Euler, mass + dt*diffusion for each species.
  a_D_psiA = v_A*u_A*dx + dt*dot(grad(v_A),D*grad(u_A))*dx
  a_D_psiB = v_B*u_B*dx + dt*dot(grad(v_B),D*grad(u_B))*dx
  L_D_psiA = v_A*u0_A*dx
  L_D_psiB = v_B*u0_B*dx
  a_D = a_D_psiA + a_D_psiB
  L_D = L_D_psiA + L_D_psiB

#========================
#  Reactions
#========================
# NOTE(review): the original section header read "Reactions - check this".
# Equilibrium problem for the species concentrations c_A, c_B, c_C, given the
# transported totals psi_A = u0_A and psi_B = u0_B:
#   c_A + c_C = psi_A,   c_B + c_C = psi_B,   c_A*c_B = c_C
# (an equilibrium constant of 1), posed on a CG1 mixed space.
Q_R = FunctionSpace(mesh,"CG",1)
W_R = Q_R*Q_R*Q_R
dc_A,dc_B,dc_C = TrialFunctions(W_R)
q_A,q_B,q_C = TestFunctions(W_R)
c = Function(W_R)
c_A,c_B,c_C = split(c)

# Bounds 0 and 1 on c, presumably for a bounded (variational-inequality)
# solve. They are only referenced by the commented-out solver_R.solve call
# in the time loop.
c_lb = Function(W_R)
c_ub = Function(W_R)
c_lb.assign(0)
c_ub.assign(1)
# Residual F and its Jacobian J (the derivative of F with respect to c).
F = (q_A*(c_A + c_C - u0_A) + q_B*(c_B + c_C - u0_B) + q_C*(c_A*c_B - c_C))*dx
J = (q_A*(dc_A + dc_C) + q_B*(dc_B + dc_C) + q_C*(c_B*dc_A + c_A*dc_B - dc_C))*dx
problem_R = NonlinearVariationalProblem(F,c,J=J,nest=False)
solver_R = NonlinearVariationalSolver(problem_R,solver_parameters=R_parameters,options_prefix="R_")
# Rebind c_A/c_B/c_C to the subfunctions of c for assignment/output. F and J
# above already captured the UFL components from split(c).
c_A,c_B,c_C = c.split()

#========================
#  Dirichlet BCs
#========================
# Boundary marker 1: species A is 1 on the lower half (x[1] <= 0.5) and
# species B is 1 on the upper half (x[1] > 0.5), 0 elsewhere.
# Markers 2, 3 and 4: both species are 0.
# Presumably marker 1 is the x = 0 (inflow) side of the RectangleMesh -- confirm.
bc1 = DirichletBC(W.sub(0), Expression(("x[1] <= 0.5 ? 1.0 : 0.0")), (1))
bc2 = DirichletBC(W.sub(1), Expression(("x[1] > 0.5 ? 1.0 : 0.0")), (1))
bc3 = DirichletBC(W.sub(0), 0.0, (2,3,4))
bc4 = DirichletBC(W.sub(1), 0.0, (2,3,4))
# The SUPG solve imposes all four BCs. In the operator split, the advection
# step imposes only the marker-1 BCs and the diffusion step imposes all four.
if OS == 0:
  bc_SUPG = [bc1,bc2,bc3,bc4]
else:
  bc_A = [bc1,bc2]
  bc_D = [bc1,bc2,bc3,bc4]

#========================
#  Initial condition
#========================
# Zero initial state for the transported totals (u0) and the species
# concentrations (c). The component-wise assigns are presumably redundant:
# u0_A/u0_B and c_A/c_B/c_C come from split() of u0 and c. They are harmless.
u0.assign(0)
u0_A.assign(0)
u0_B.assign(0)
c.assign(0)
c_A.assign(0)
c_B.assign(0)
c_C.assign(0)

#========================
#  Optimization
#========================
# TAO solver for the bound-constrained diffusion step (used when
# optimize_D != 0). Every dof of the mixed space W is bounded to [0, 1].
lb = Function(W)
ub = Function(W)
lb.assign(0)
ub.assign(1)
tao = PETSc.TAO().create(MPI.COMM_WORLD)
with lb.dat.vec as lb_vec, ub.dat.vec as ub_vec:
  tao.setVariableBounds(lb_vec,ub_vec)

#========================
#
#  Diffusion step posed as a bound-constrained quadratic program:
#
#    min  0.5*x^T*H*x - x^T*f
#    s.t. 0 <= x <= 1
#
#  with H = A_D_nobc and f = b_D_nobc: the diffusion matrix and right-hand
#  side assembled without Dirichlet BCs. Both are module globals assigned
#  in the time loop before each tao.solve call.
#
#========================
def ObjGrad_D(tao, petsc_x, petsc_g):
  """Evaluate the diffusion-QP objective and its gradient for TAO.

  Args:
    tao: the calling TAO solver (not used in the body).
    petsc_x: current iterate. The solve is launched on u0.dat.vec, so this
      is a vector over the mixed space W.
    petsc_g: output vector. It receives H*x - f; then its entries at the
      Dirichlet nodes of every bc in bc_D are overwritten with that bc's value.

  Returns:
    The objective value 0.5*x^T*H*x - x^T*f. The BC overwrite of the
    gradient is not reflected in this value.

  NOTE(review): at Dirichlet nodes the gradient is set to the BC value
  itself, not to a quantity that vanishes when x equals the BC value
  (e.g. x - g_bc). The returned objective is not adjusted to match either,
  so BLMVM's line search may see an inconsistent objective/gradient pair.
  Confirm this is intended.

  NOTE(review): bc.nodes are indices within the bc's subspace (W.sub(0) or
  W.sub(1)), but arr spans the whole mixed vector. Nodes of W.sub(1)
  presumably need an offset; verify against Firedrake's mixed-vector layout.
  """
  with b_D_nobc.dat.vec as b_D_vec:
    # g = H*x; reused for x^T*H*x before being turned into H*x - f
    A_D_nobc.M.handle.mult(petsc_x,petsc_g)
    xtHx = petsc_x.dot(petsc_g)
    xtf = petsc_x.dot(b_D_vec)
    petsc_g.axpy(-1.0,b_D_vec)
    # presumably a writable numpy view of the local entries of petsc_g,
    # so the writes below update the gradient in place -- confirm
    arr = petsc_g.array
    for bc in bc_D:
      # presumably drops non-owned (halo) nodes, taking dof_dset.size as the
      # locally owned dof count -- confirm
      nodes = bc.nodes[bc.nodes < bc.function_space().dof_dset.size]
      if isinstance(bc.function_arg, Function):
        # spatially varying BC value: copy it node by node
        with bc.function_arg.dat.vec_ro as bc_vec:
          arr[nodes] = bc_vec.array_r[nodes]
      else:
        # constant BC value
        arr[nodes] = float(bc.function_arg)
    return 0.5*xtHx - xtf
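# Old approach (kept for reference, unused): objective and gradient assembled from UFL forms; superseded by ObjGrad_D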
#objective_D = 0.5*(u0_A*u0_A*dx + dt*dot(grad(u0_A),D*grad(u0_A))*dx) - u0_A*u0_A*dx + 0.5*(u0_B*    u0_B*dx + dt*dot(grad(u0_B),D*grad(u0_B))*dx) - u0_B*u0_B*dx
#gradient_D = action(a_D,u0) - L_D

#g = Function(W)
#def objective(tao, petsc_x):
#  with u0.dat.vec as v:
#    petsc_x.copy(v)
#  hom_bcs = [homogenize(bc) for bc in bc_D]
#  objective = assemble(objective_D,bcs=hom_bcs)
#  return objective

#def gradient(tao, petsc_x, petsc_g):
#  with u0.dat.vec as v:
#    petsc_x.copy(v)
#  hom_bcs = [homogenize(bc) for bc in bc_D]
#  gr = assemble(gradient_D,bcs=hom_bcs,tensor=g)
#  with gr.dat.vec_ro as v:
#    v.copy(petsc_g)

# Register the combined objective/gradient callback and use the
# bound-constrained limited-memory variable metric method (BLMVM).
# NOTE(review): which tolerances the two positional values set depends on
# the petsc4py TAO.setTolerances signature in use -- confirm. setFromOptions
# lets command-line PETSc options override these settings.
tao.setObjectiveGradient(ObjGrad_D)
tao.setType(PETSc.TAO.Type.BLMVM)
tao.setTolerances(1e-8,1e-8)
tao.setFromOptions()

#========================
#  Scipy
#========================
# Zero-initialized work arrays for the (currently commented-out) pointwise
# fsolve of the equilibrium system, each sized like u0_A.vector().
cA_vec, cB_vec, cC_vec = (np.zeros(len(u0_A.vector())) for _ in range(3))

#========================
#  AD linear operators
#========================
# Start the wall-clock timer reported at the end of the run. Then assemble
# the left-hand-side matrices once, with their Dirichlet BCs, and wrap them
# in linear solvers. These forms depend only on dt and the fixed velocity;
# OSsolve re-assembles only the right-hand side each step.
initialTime = time.time()
if OS == 0:
  A_SUPG = assemble(a_SUPG,bcs=bc_SUPG)
  solver_SUPG = LinearSolver(A_SUPG,options_prefix="SUPG_",solver_parameters=SUPG_parameters)
else:
  A_D = assemble(a_D,bcs=bc_D)
  A_A = assemble(a_A,bcs=bc_A)
  solver_A = LinearSolver(A_A,options_prefix="A_",solver_parameters=OS_parameters)
  solver_D = LinearSolver(A_D,options_prefix="D_",solver_parameters=OS_parameters)

#=======================
#  AD solver functions
#=======================
def OSsolve(solver_O,L_O,bc_O,u0):
  """Assemble the right-hand side L_O with BCs bc_O and solve into u0.

  solver_O is a pre-built LinearSolver whose operator was assembled with
  the matching BCs. The result overwrites u0 in place. u0 is also returned,
  so callers can write ``u0 = OSsolve(...)``.
  """
  solver_O.solve(u0, assemble(L_O, bcs=bc_O))
  return u0

#=======================
#  Time-stepping
#=======================
# Use an integer step count instead of accumulating `t += dt` and testing
# `t < T`. Repeated floating-point addition of dt can drift, so the float
# loop may run one step too many or too few. ceil() matches the
# `while t < T` step count when dt does not divide T. The small tolerance
# absorbs round-off in T/dt when it does.
num_steps = int(math.ceil(T/dt - 1e-8))

# The diffusion operator a_D depends only on dt and the fixed velocity, so
# assemble it once for the TAO objective instead of every step. ObjGrad_D
# reads A_D_nobc as a module global.
if OS != 0 and optimize_D != 0:
  A_D_nobc = assemble(a_D)

for step in range(num_steps):
  t = step*dt
  if rank == 0:
    print('\tTime = %f' % t)
  # Advection-diffusion
  if OS == 0:
    # Monolithic SUPG solve for both species.
    u0 = OSsolve(solver_SUPG,L_SUPG,bc_SUPG,u0)
    solA,solB = u0.split()
    u0_A.assign(solA)
    u0_B.assign(solB)
  else:
    # Operator splitting: advection step first ...
    u0 = OSsolve(solver_A,L_A,bc_A,u0)
    solA,solB = u0.split()
    u0_A.assign(solA)
    u0_B.assign(solB)
    # ... then the diffusion step, as a plain linear solve or as a
    # bound-constrained QP solved with TAO.
    if optimize_D == 0:
      u0 = OSsolve(solver_D,L_D,bc_D,u0)
    else:
      # The right-hand side depends on the current u0_A/u0_B, so it is
      # re-assembled every step. ObjGrad_D reads b_D_nobc as a module global.
      b_D_nobc = assemble(L_D)
      with u0.dat.vec as u0_vec:
        tao.solve(u0_vec)
    solA,solB = u0.split()
    u0_A.assign(solA)
    u0_B.assign(solB)

  # Output total concentrations
  outFile_psiA << u0_A
  outFile_psiB << u0_B

  # Solve reactions
  #u0_A.project(Expression(("sin(pi*x[0]/2)")))
  #u0_B.project(Expression(("1+cos(pi*x[1])")))
  #solver_R.solve()#lb=c_lb,ub=c_ub)

  # scipy
  #def equations(p,psiA,psiB):
  #  cA,cB,cC = p
  #  return (cA+cC-psiA, cB+cC-psiB, cC-cA*cB)
  #for i in range(len(u0_A.vector())):
  #  cA_i,cB_i,cC_i = fsolve(equations,(0.0,0.0,0.0),args=(u0_A.vector()[i],u0_B.vector()[i]))
  #  cA_vec[i] = cA_i
  #  cB_vec[i] = cB_i
  #  cC_vec[i] = cC_i
  #c_A.vector()[:] = cA_vec
  #c_B.vector()[:] = cB_vec
  #c_C.vector()[:] = cC_vec
  # Output geochemical concentrations
  #conc_A,conc_B,conc_C = c.split()
  #outFile_A << c_A#conc_A
  #outFile_B << c_B#conc_B
  #outFile_C << c_C#conc_C
# Leave t at the final simulated time, as the original accumulation did.
t = num_steps*dt

#=======================
#  End simulation
#=======================
# Elapsed wall-clock time since initialTime (set just before operator
# assembly), reported once by rank 0.
getTime = time.time() - initialTime
if rank == 0:
  print('\tWall-clock time: %.3e seconds' % getTime)
