#================================================
#
#  Advective-diffusive-reactive 2D boundary 
#  By: Justin Chang
#
#  2A + B <=> 2C
#
#  Run as:
#    python 2D_OS_boundary_ex3.py <xseed> <yseed> <opt_A> <opt_D> <VI>
#
#  <x/yseed> = number of elements in each direction
#  <opt_A/D> = standard linear solve (0) or bound-constrained optimization (1) for the advection (A) / diffusion (D) operator(s)
#  <VI> = Use Newton (0) or VI (1) solver
#
#================================================
# Time levels (1-based step counts) at which the reaction kernel is evaluated
# and the fields are written to disk.  Put 0 in the set (e.g. {0}) to write
# output at every time level.
printat = {1,10,100}
alpha_penalty = 40         # Interior-penalty parameter for the DG diffusion form
gamma_penalty = 8          # Penalty for the (currently disabled) weak Dirichlet terms
dt = 0.01                  # Global time-step
T = 1.00                   # Maximum time
opt_lb_psiA = 0.0          # Optimization lower-bound psiA
opt_up_psiA = 1.0          # Optimization upper-bound psiA
opt_lb_psiB = 0.0          # Optimization lower-bound psiB
opt_up_psiB = 1.0          # Optimization upper-bound psiB
problem = 3                # Problem ID
keq = 2.0                  # Equilibrium constant
aD = 1e-6                  # Molecular diffusion
at = 1e-4                  # Transverse dispersivity
aL = 1                     # Longitudinal dispersivity
tort = 1                   # Tortuosity
TOL = 1e-7                 # Tolerance passed to the TAO optimization solvers

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import time
import sys
import math
rank = PETSc.COMM_WORLD.getRank()
parameters["matnest"] = False
op2.init(log_level='WARNING')

# Command-line arguments (see the usage line in the file header).  Exit with a
# usage message instead of an IndexError when arguments are missing.
if len(sys.argv) < 6:
  sys.exit('Usage: python %s <xseed> <yseed> <opt_A> <opt_D> <VI>' % sys.argv[0])
xseed = int(sys.argv[1])   # number of elements in x
yseed = int(sys.argv[2])   # number of elements in y
opt_A = int(sys.argv[3])   # advection: 0 = linear solve, otherwise TAO bound-constrained solve
opt_D = int(sys.argv[4])   # diffusion: 0 = linear solve, otherwise TAO bound-constrained solve
opt_VI = int(sys.argv[5])  # VI switch; NOTE(review): in this file it only tags output file names

#========================
#  Create mesh 
#========================
# Quadrilateral mesh of the 2 x 1 rectangle with xseed by yseed cells.
mesh = RectangleMesh(xseed, yseed, 2, 1, quadrilateral=True)
V = VectorFunctionSpace(mesh, "DG", 1)   # advection velocity space
Q = FunctionSpace(mesh, "DG", 1)         # scalar space for one species
W = Q*Q                                  # mixed space holding (psiA, psiB)

u_A, u_B = TrialFunctions(W)
v_A, v_B = TestFunctions(W)

# Transported solution; the sub-step solves write into it in place.
u0 = Function(W)
u0_A, u0_B = u0.split()

# Inflow data used by the advection form on ds(1).
bcIn = Function(W)
bcInA, bcInB = bcIn.split()

#========================
#  Advection velocity
#========================
# Prescribed advection velocity interpolated into the DG1 vector space V.
# Each cosine mode of the x-component is paired with a sine mode of the
# y-component so that d(vx)/dx + d(vy)/dy = 0 term by term (the field is
# divergence-free).  The x-component is 1 plus perturbations whose amplitudes
# sum to 0.28*pi (about 0.88), so it stays positive.  That keeps |v| > 0 in
# the dispersion tensor.  The y-component vanishes on y = 0 and y = 1.
velocity = interpolate(Expression(("1 + 0.08*pi*cos(4*pi*x[0]*0.5 - 0.5*pi)*cos(pi*x[1]) + 0.1*pi*cos(5*pi*x[0]*0.5 - 0.5*pi)*cos(5*pi*x[1]) + 0.1*pi*cos(10*pi*x[0]*0.5 - 0.5*pi)*cos(10*pi*x[1])","0.16*pi*sin(2*pi*x[0] - 0.5*pi)*sin(pi*x[1]) + 0.05*pi*sin(5*pi*x[0]*0.5 - 0.5*pi)*sin(5*pi*x[1]) + 0.05*pi*sin(5*pi*x[0] - 0.5*pi)*sin(10*pi*x[1])")),V)

#========================
#  Supplementary files
#========================
# Load the equilibrium-reaction helpers for this problem into the module
# namespace.  The time loop uses `reactions3_kernel` and `ctx`, which are not
# defined in this file, so they presumably come from here -- TODO confirm.
# execfile exists only in Python 2.
execfile('database/equilibrium_rxns3.py')

#========================
#  Dispersion tensor
#========================
# Velocity-dependent dispersion tensor
#   D = (alpha_D + alpha_t*|v|) I + (alpha_L - alpha_t) (v outer v) / |v|
# where alpha_D already includes the tortuosity factor.  |v| must be nonzero
# wherever D is evaluated.
normv = sqrt(dot(velocity, velocity))
Id = Identity(mesh.ufl_cell().geometric_dimension())
alpha_t = Constant(at)
alpha_L = Constant(aL)
alpha_D = Constant(aD*tort)
isotropic_part = (alpha_D + alpha_t*normv)*Id
D = isotropic_part + (alpha_L - alpha_t)*outer(velocity, velocity)/normv

#========================
#  Volumetric source
#========================
# Volumetric sources for psiA and psiB (identically zero in this problem).
# The two factors split each source between the advection and diffusion
# sub-steps.
f_A = Function(Q)
f_B = Function(Q)
f_A_factor = Constant(0.5)   # share applied in the advection step
f_D_factor = Constant(0.5)   # share applied in the diffusion step
for _src in (f_A, f_B):
  _src.interpolate(Expression("0.0"))

#========================
#  Time-stepping
#========================
# Each operator-split sub-step may use a scaled copy of the global dt.
tscale_A = 1.0
tscale_D = 1.0
dt_A, dt_D = Constant(dt*tscale_A), Constant(dt*tscale_D)
t, tlevel = 0, 0   # current time and time-level counter

#========================
#  Solver parameters
#========================
# Krylov settings shared by the advection ("A_") and diffusion ("D_") linear
# solvers.  Alternative preconditioner kept for reference:
#   'pc_type': 'hypre', 'pc_hypre_type': 'boomeramg'
OS_parameters = dict(ksp_type='gmres', pc_type='bjacobi')

#========================
#  Output files
#========================
# ParaView output under Figures/<script path without extension>/.  Field
# files are tagged with the opt_A/opt_D/opt_VI switches, so runs with
# different settings write to different files.
_out_dir = 'Figures/' + __file__.rsplit('.',1)[0]
_out_tag = str(opt_A) + str(opt_D) + str(opt_VI) + '.pvd'

outFile_velocity = File(_out_dir + "/vel.pvd")
outFile_velocity.write(velocity)
outFile_A = File(_out_dir + '/cA' + _out_tag)
outFile_B = File(_out_dir + '/cB' + _out_tag)
outFile_C = File(_out_dir + '/cC' + _out_tag)
outFile_psiA = File(_out_dir + '/psiA' + _out_tag)
outFile_psiB = File(_out_dir + '/psiB' + _out_tag)

#========================
#  Weak form
#========================
n = FacetNormal(mesh)
# Mesh-size constants.  Use float division: under Python 2, 1/yseed is integer
# division and evaluates to 0 for yseed > 1.  The interior penalty
# alpha/h_avg would then divide by zero.  1/yseed is the cell height of the
# unit-height mesh.
h = Constant(1.0/yseed)
#h = CellSize(mesh)
alpha = Constant(alpha_penalty)
gamma = Constant(gamma_penalty)
h_avg = Constant(1.0/yseed)  # cell-size-based alternative: 0.5*(h('+')+h('-'))
# Inflow data: psiA = 1 on the lower half, psiB = 1 on the upper half.
bcInA.interpolate(Expression("x[1] < 0.5 ? 1.0 : 0.0"))
bcInB.interpolate(Expression("x[1] > 0.5 ? 1.0 : 0.0"))
# Upwind coefficient max(v.n, 0): nonzero only where flow leaves a cell.
vn = 0.5*(dot(velocity,n) + abs(dot(velocity,n)))

# Advection weak form: one implicit-Euler step over dt_A for each species
# (the unknowns u_A/u_B appear inside the dt_A terms).
#  - Interior facets use the upwind flux vn('+')*u('+') - vn('-')*u('-'),
#    with vn = max(v.n, 0).
#  - ds(2) is the outflow term.
#  - ds(1) applies the inflow data bcIn weakly through v.n.
# In Firedrake's RectangleMesh numbering, ds(1) and ds(2) are the x = 0 and
# x = Lx edges.  No terms appear on ds(3)/ds(4) because the prescribed
# velocity's y-component vanishes at y = 0 and y = 1.
a_A_psiA = v_A*u_A*dx - dt_A*(u_A*dot(velocity,grad(v_A))*dx \
    - dot(jump(v_A),vn('+')*u_A('+')-vn('-')*u_A('-'))*dS - dot(v_A,vn*u_A)*ds(2))
a_A_psiB = v_B*u_B*dx - dt_A*(u_B*dot(velocity,grad(v_B))*dx \
    - dot(jump(v_B),vn('+')*u_B('+')-vn('-')*u_B('-'))*dS - dot(v_B,vn*u_B)*ds(2))
# f_A_factor is the advection-step share of the volumetric source; it is
# applied to both species.
L_A_psiA = v_A*(u0_A + dt_A*f_A_factor*f_A)*dx \
    - dt_A*bcInA*v_A*dot(velocity,n)*ds(1)
L_A_psiB = v_B*(u0_B + dt_A*f_A_factor*f_B)*dx \
    - dt_A*bcInB*v_B*dot(velocity,n)*ds(1)
# The species are uncoupled, so the mixed form is the plain sum.
a_A = a_A_psiA + a_A_psiB
L_A = L_A_psiA + L_A_psiB


# Diffusion weak form: one implicit-Euler step over dt_D using the symmetric
# interior penalty (SIPG) discretization of -div(D grad psi).  The interior
# penalty is avg(alpha/h_avg).  No boundary terms are active: the weak
# (Nitsche) Dirichlet terms below are commented out, and the Dirichlet data
# is imposed strongly through bc_D instead.
# NOTE(review): D depends on the DG velocity but is used unrestricted inside
# the dS integrals.  Recent UFL versions require discontinuous coefficients
# to be restricted on interior facets -- confirm with the UFL version in use.
a_D_psiA = v_A*u_A*dx + dt_D*(dot(grad(v_A),D*grad(u_A))*dx \
    - dot(jump(v_A,n),D*avg(grad(u_A)))*dS - dot(D*avg(grad(v_A)),jump(u_A,n))*dS \
    + avg(alpha/h_avg)*dot(jump(v_A,n),jump(u_A,n))*dS) #+ gamma/h*v_A*u_A*ds \
    #- dot(v_A*n,grad(u_A))*ds - dot(grad(v_A),u_A*n)*ds)
a_D_psiB = v_B*u_B*dx + dt_D*(dot(grad(v_B),D*grad(u_B))*dx \
    - dot(jump(v_B,n),D*avg(grad(u_B)))*dS - dot(D*avg(grad(v_B)),jump(u_B,n))*dS \
    + avg(alpha/h_avg)*dot(jump(v_B,n),jump(u_B,n))*dS) #+ gamma/h*v_B*u_B*ds \
    #- dot(v_B*n,grad(u_B))*ds - dot(grad(v_B),u_B*n)*ds)
# Right-hand sides: previous solution plus the diffusion-step share of the
# source (f_D_factor).
L_D_psiA = v_A*u0_A*dx + dt_D*(v_A*f_D_factor*f_A)*dx #\
    #+ bcInA*(v_A*gamma/h - dot(D*grad(v_A),n))*ds(1)) #\
    #+ bcOutA*(v_A*gamma/h - dot(D*grad(v_A),n))*ds(2))
L_D_psiB = v_B*u0_B*dx + dt_D*(v_B*f_D_factor*f_B)*dx #\
    #+ bcInB*(v_B*gamma/h - dot(D*grad(v_B),n))*ds(1)) #\
    #+ bc_B_out*(v_B*gamma/h - dot(D*grad(v_B),n))*ds(2))
# The species are uncoupled, so the mixed form is the plain sum.
a_D = a_D_psiA + a_D_psiB
L_D = L_D_psiA + L_D_psiB


#========================
#  Reactions
#========================
# Reaction unknowns: concentrations of A, B and C in the mixed space Q*Q*Q.
# The reaction kernel in the time loop writes them from psiA/psiB.
# (A second, identical set of these definitions used to precede this one; it
# was dead because every name was immediately rebound, so it was removed.)
W_R = Q*Q*Q
c = Function(W_R)
cv_A,cv_B,cv_C = TestFunctions(W_R)
c_A,c_B,c_C = c.split()
# Equilibrium constant of 2A + B <=> 2C, passed to the reaction kernel.
# NOTE(review): the removed dead copy used a per-dof k1 = Function(Q) filled
# with keq.  Confirm that reactions3_kernel expects a global constant here.
k1 = Constant(keq)

#========================
#  Dirichlet BCs
#========================
# Strong Dirichlet conditions, imposed geometrically on the DG space.
# Boundary 1 carries the inflow profile: psiA = 1 on the lower half and
# psiB = 1 on the upper half.  Both species are 0 on boundaries 2, 3 and 4.
# NOTE(review): bc1 includes x[1] == 0.5 (<=), while the advection inflow data
# bcInA uses x[1] < 0.5 -- confirm which is intended.
bc1 = DirichletBC(W.sub(0), Expression("x[1] > 0.0 && x[1] <= 0.5 ? 1.0 : 0"),
                  1, method="geometric")
bc2 = DirichletBC(W.sub(1), Expression("x[1] > 0.5 && x[1] < 1.0 ? 1.0 : 0.0"),
                  1, method="geometric")
_walls = (2, 3, 4)
bc3 = DirichletBC(W.sub(0), 0.0, _walls, method="geometric")
bc4 = DirichletBC(W.sub(1), 0.0, _walls, method="geometric")
bc_A = [bc1, bc2]             # inflow conditions only (advection sub-step)
bc_D = bc_A + [bc3, bc4]      # all boundaries (diffusion sub-step)

#========================
#  Initial condition
#========================
# Zero initial state for the transported (psi) and reacted (c) fields.
for _field in (u0, c):
  _field.assign(0)

#========================
#  Optimization 
#========================
opt_flag = 'D'
# Box bounds [lb, ub] for the bound-constrained TAO solves of psiA/psiB.
lb = Function(W)
lb_A,lb_B = lb.split()
ub = Function(W)
ub_A,ub_B = ub.split()
lb_A.assign(opt_lb_psiA)
lb_B.assign(opt_lb_psiB)
ub_A.assign(opt_up_psiA)
ub_B.assign(opt_up_psiB)
tao_A = PETSc.TAO().create(PETSc.COMM_WORLD)
tao_D = PETSc.TAO().create(PETSc.COMM_WORLD)
# Both bound vectors are only read, so access both read-only.  Read-write
# access would mark ub's data as modified for no reason.
with lb.dat.vec_ro as lb_vec, ub.dat.vec_ro as ub_vec:
  tao_A.setVariableBounds(lb_vec,ub_vec)
  tao_D.setVariableBounds(lb_vec,ub_vec)

def Hessian_O(tao, petsc_x, petsc_H, petsc_HP):
  """TAO Hessian callback that intentionally does nothing.

  The objectives are quadratic in a fixed operator that is assembled once and
  registered with TAO via setHessian, so nothing is recomputed per iterate.
  """
  return None

def ObjGrad_A(tao, petsc_x, petsc_g):
  """TAO objective/gradient callback for the advection sub-step.

  Evaluates f(x) = 0.5*x^T A x - x^T b and writes A x - b into petsc_g.
  A is the assembled advection operator A_A, and b is the module-level
  right-hand side assigned in the time loop.

  NOTE(review): the advection form is not symmetric in trial/test functions.
  A x - b is therefore not the exact gradient of f (that would be
  0.5*(A + A^T) x - b).  Confirm this is acceptable for BLMVM.

  Returns the objective value f(x).
  """
  # b is only read, so use read-only access.  This avoids flagging its data
  # as modified on every objective evaluation.
  with b.dat.vec_ro as b_vec:
    A_A.M.handle.mult(petsc_x,petsc_g)   # g = A x
    xtHx = petsc_x.dot(petsc_g)          # x^T A x
    xtf = petsc_x.dot(b_vec)             # x^T b
    petsc_g.axpy(-1.0,b_vec)             # g = A x - b
  return 0.5*xtHx - xtf

def ObjGrad_D(tao, petsc_x, petsc_g):
  """TAO objective/gradient callback for the diffusion sub-step.

  Evaluates f(x) = 0.5*x^T A x - x^T b and writes its gradient A x - b into
  petsc_g.  A is the assembled diffusion operator A_D, and b is the
  module-level right-hand side assigned in the time loop.

  Returns the objective value f(x).
  """
  # b is only read, so use read-only access.  This avoids flagging its data
  # as modified on every objective evaluation.
  with b.dat.vec_ro as b_vec:
    A_D.M.handle.mult(petsc_x,petsc_g)   # g = A x
    xtHx = petsc_x.dot(petsc_g)          # x^T A x
    xtf = petsc_x.dot(b_vec)             # x^T b
    petsc_g.axpy(-1.0,b_vec)             # g = A x - b
  return 0.5*xtHx - xtf

# Configure both TAO solvers the same way: the BLMVM bound-constrained
# quasi-Newton method, with the quadratic objective callbacks defined above.
for _tao, _objgrad in ((tao_D, ObjGrad_D), (tao_A, ObjGrad_A)):
  _tao.setObjectiveGradient(_objgrad)
  _tao.setType(PETSc.TAO.Type.BLMVM)
  _tao.setTolerances(TOL,TOL)

#========================
#  AD linear operators
#========================
# Start the wall-clock timer, then assemble both operators once.  They are
# reused at every time step; only the right-hand sides are re-assembled.
initialTime = time.time()
A_D = assemble(a_D, bcs=bc_D)
A_A = assemble(a_A)   # assembled without bc_A (kept disabled)
tao_D.setHessian(Hessian_O, A_D.M.handle)
tao_A.setHessian(Hessian_O, A_A.M.handle)
solver_A = LinearSolver(A_A, options_prefix="A_", solver_parameters=OS_parameters)
solver_D = LinearSolver(A_D, options_prefix="D_", solver_parameters=OS_parameters)

#=======================
#  RHS for optimization
#=======================
# Lifting vectors for the optimization path.  Each one holds the Dirichlet
# values of its own BC set (zero elsewhere).  rhs_X = a_X(., lift_X) is
# subtracted from the load vector in ADassemble when opt_X != 0.  Separate
# functions keep rhs_A independent of the values set by bc_D.
lift_D = Function(W)
for bc in bc_D:
  bc.apply(lift_D)
rhs_D = assemble(action(a_D,lift_D))
lift_A = Function(W)
for bc in bc_A:
  bc.apply(lift_A)
rhs_A = assemble(action(a_A,lift_A))
tmp = Function(W)
def ADassemble(A_O,a_O,L_O,bc_O,opt_O,rhs_O):
  """Assemble the right-hand side for one advection or diffusion sub-step.

  Parameters
  ----------
  A_O, a_O : the sub-step operator and its bilinear form.  Unused here,
      because re-assembly of the operator each step is disabled.
  L_O : linear form of the sub-step.
  bc_O : list of DirichletBC to apply to the result; None means no BCs.
  opt_O : 0 for a plain linear solve.  Any other value also subtracts the
      lifting vector rhs_O, for the TAO optimization path.
  rhs_O : lifting vector a_O(., g) built from the Dirichlet data g.

  Returns the assembled, BC-modified right-hand side.
  """
  b = assemble(L_O)
  if opt_O != 0:
    b -= rhs_O
  # Treat None like an empty list instead of failing on iteration.
  for bc in (bc_O or []):
    bc.apply(b)
  return b

def ADsolve(solver_O,tao_O,opt_O,u0_O):
  """Solve one sub-step in place into u0_O and return u0_O.

  With opt_O == 0, the linear solver is applied to the module-level
  right-hand side ``b``.  Otherwise the bound-constrained TAO solver runs,
  starting from and overwriting u0_O; its objective callbacks read the same
  module-level ``b``.
  """
  if opt_O != 0:
    with u0_O.dat.vec as sol_vec:
      tao_O.solve(sol_vec)
    return u0_O
  solver_O.solve(u0_O,b)
  return u0_O

#=======================
#  Time-stepping
#=======================
# Number of time levels: step until t >= T.  The count is computed once as an
# integer, so floating-point accumulation of dt cannot add or drop a step.
# The small offset absorbs rounding in T/dt.
nsteps = int(math.ceil(T/dt - 1e-9))
while tlevel < nsteps:
  tlevel += 1
  t = tlevel*dt
  if rank == 0:
    print('\tTime = %0.3f' % t)

  # Advection-diffusion (the advection sub-step is currently disabled)
  #with PETSc.Log.Stage("Advection"):
  #  b = ADassemble(A_A,a_A,L_A,None,opt_A,rhs_A)
  #  u0 = ADsolve(solver_A,tao_A,opt_A,u0)
  with PETSc.Log.Stage("Diffusion"):
    b = ADassemble(A_D,a_D,L_D,bc_D,opt_D,rhs_D)
    u0 = ADsolve(solver_D,tao_D,opt_D,u0)

  # Reactions and output, only at the requested time levels (0 in printat
  # means every level).  A single membership test writes each level at most
  # once, even when printat contains both 0 and tlevel.
  if tlevel in printat or 0 in printat:
    #with u0_A.dat.vec as psiA_vec, u0_B.dat.vec as psiB_vec:
    #  psiA_vec.chop(TOL)
    #  psiB_vec.chop(TOL)
    op2.par_loop(reactions3_kernel,Q.dof_dset.set,u0_A.dat(op2.READ),u0_B.dat(op2.READ),c_A.dat(op2.WRITE),c_B.dat(op2.WRITE),c_C.dat(op2.WRITE),k1.dat(op2.READ),ctx(op2.READ))
    # Zero out tiny round-off values before writing.
    with c_A.dat.vec as cA_vec, c_B.dat.vec as cB_vec, c_C.dat.vec as cC_vec:
      cA_vec.chop(1e-9)
      cB_vec.chop(1e-9)
      cC_vec.chop(1e-9)
    outFile_psiA << u0_A
    outFile_psiB << u0_B
    outFile_A << c_A
    outFile_B << c_B
    outFile_C << c_C

#=======================
#  End simulation
#=======================
# Report the wall-clock time since operator assembly started (rank 0 only).
getTime = time.time() - initialTime
if rank == 0:
  print('\tWall-clock time: %.3e seconds' % getTime)
