#================================================
#
#  Advective-diffusive-reactive 1D example
#  By: Justin Chang
#
#  ex1: psi form 1 - ADR
#
#================================================
# Time levels at which output is written; a 0 in the set means every level.
printat = {0}
#printat = {1,10,50,100}   # alternative: write only selected levels
problem = 1                # Problem ID
vel = 1                    # Velocity
# Float literal on purpose: under Python 2, 1/150 is integer division and
# yields 0, which would silently switch diffusion off.
diffusivity = 1.0/150      # Diffusivity
dt = 0.01                  # Global time-step
T = 1.00                   # Maximum time
tstep_A = 0.5*dt           # Advection time-step
tstep_D = 0.5*dt           # Diffusion time-step (halved again when DAD == 1)
keq = 0.1                  # Equilibrium constant
numKeq = 1                 # Number of mass action eqns
TOL = 1e-11                # Tolerance (entries below this are chopped to 0 before reactions)

import sys
# Command-line arguments: <seed> <DAD> <opt_VI>. Fail with a usage message
# instead of a bare IndexError when any are missing.
if len(sys.argv) < 4:
  sys.exit('Usage: python %s <seed> <DAD> <opt_VI>' % sys.argv[0])
seed = int(sys.argv[1])    # Number of cells in the unit-interval mesh
DAD = int(sys.argv[2])     # 1: diffusion-advection-diffusion splitting; otherwise advection-diffusion
opt_VI = int(sys.argv[3])  # Parsed but not referenced elsewhere in this file as shown
#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import time
# MPI rank of this process; only rank 0 prints progress and timing below.
rank = PETSc.COMM_WORLD.getRank()
# NOTE(review): presumably disables nested (MatNest) matrices so mixed-space
# operators are assembled as one monolithic matrix -- confirm against the
# Firedrake version in use.
parameters["matnest"] = False

#========================
#  Create mesh 
#========================
# Unit interval split into `seed` cells. With P1 ("CG", 1) elements each field
# has seed+1 dofs, which the time loop relies on when reshaping output vectors.
mesh = UnitIntervalMesh(seed)
Q = FunctionSpace(mesh, "CG", 1)
# Mixed space for the two transported fields (psi_A, psi_B).
W = Q*Q
u_A,u_B = TrialFunctions(W)
v_A,v_B = TestFunctions(W)
# Current solution. u0_A/u0_B are its components; later code reads them after
# solving into u0, so it relies on them sharing storage with u0.
u0 = Function(W)
u0_A,u0_B = u0.split()

#========================
#  Advection velocity
#========================
# Uniform 1D advection velocity.
velocity = as_vector((Constant(vel),))

#========================
#  Dispersion tensor
#========================
# Isotropic: a scalar coefficient multiplying grad(u) in the diffusion form.
D = Constant(diffusivity)

#========================
#  Volumetric source
#========================
# With DAD == 1 the diffusion substep runs twice per time step, so both its
# step size and its share of the psi_A source are halved.
split_diffusion = (DAD == 1)
f_DAD = 0.5 if split_diffusion else 1.0
if split_diffusion:
  tstep_D *= 0.5
# Sources applied in the advection (_A) and diffusion (_D) substeps.
f_A_A, f_B_A = Constant(0.5), Constant(0.0)
f_A_D, f_B_D = Constant(0.5*f_DAD), Constant(0.0)

#========================
#  Time-stepping
#========================
# Simulation clock (current time, completed step count) and the substep
# sizes wrapped as Constants for use in the weak forms.
t, tlevel = 0, 0
dt_A, dt_D = Constant(tstep_A), Constant(tstep_D)

#========================
#  Solver parameters
#========================
# Both split operators use PETSc conjugate gradients with algebraic multigrid.
OS_parameters = dict(ksp_type='cg', pc_type='gamg')

#========================
#  Output files
#========================
# Outputs go under a directory named after this script (its path minus the
# extension); the DAD flag is appended to each file stem.
outPrefix = __file__.rsplit('.',1)[0]
def _pvdFile(stem):
  """Open the VTK output file ``outPrefix + stem + str(DAD) + '.pvd'``."""
  return File(outPrefix + stem + str(DAD) + '.pvd')
outFile_A = _pvdFile('/c_A')
outFile_B = _pvdFile('/c_B')
outFile_C = _pvdFile('/c_C')
outFile_psiA = _pvdFile('/psi_A')
outFile_psiB = _pvdFile('/psi_B')

#========================
#  Weak form
#========================
# Advection weak form: least-squares form of backward-Euler advection,
#   (S v, S u) = (S v, u0 + dt_A*f)   with   S w = w + dt_A*(velocity . grad w)
def _advShift(w):
  """Return w + dt_A*(velocity . grad w)."""
  return w + dt_A*dot(velocity,grad(w))
a_A_psiA = _advShift(v_A)*_advShift(u_A)*dx
a_A_psiB = _advShift(v_B)*_advShift(u_B)*dx
L_A_psiA = _advShift(v_A)*(u0_A + dt_A*f_A_A)*dx
L_A_psiB = _advShift(v_B)*(u0_B + dt_A*f_B_A)*dx
a_A = a_A_psiA + a_A_psiB
L_A = L_A_psiA + L_A_psiB
# Diffusion weak form: backward-Euler mass + dt_D*stiffness on the left,
# previous state plus dt_D*source on the right.
a_D_psiA = v_A*u_A*dx + dt_D*dot(grad(v_A),D*grad(u_A))*dx
a_D_psiB = v_B*u_B*dx + dt_D*dot(grad(v_B),D*grad(u_B))*dx
L_D_psiA = v_A*u0_A*dx + v_A*dt_D*f_A_D*dx
L_D_psiB = v_B*u0_B*dx + v_B*dt_D*f_B_D*dx
a_D = a_D_psiA + a_D_psiB
L_D = L_D_psiA + L_D_psiB

#========================
#  Reactions
#========================
# Mixed space for the three species concentrations c_A, c_B, c_C. They are
# filled from psi (u0) by the reaction kernel at output time.
W_R = Q*Q*Q
c = Function(W_R)
c_A,c_B,c_C = c.split()
# Equilibrium constant as a Constant so its data can be passed to the kernel.
k1 = Constant(keq)
# Defines `reactions_kernel` (used by op2.par_loop in the time loop) by
# executing that file in this namespace. The path is resolved relative to the
# current working directory, not this script's location. Python 2 only.
execfile('reactions_kernel.py')

#========================
#  Dirichlet BCs
#========================
# Dirichlet data on boundary markers 1 and 2 (presumably the x=0 and x=1 ends
# of the unit interval -- confirm against Firedrake's UnitIntervalMesh).
# Sub-space 0 carries psi_A, sub-space 1 carries psi_B.
bc1 = DirichletBC(W.sub(0), 0.0, 1)
bc2 = DirichletBC(W.sub(1), 1.0, 1)
bc3 = DirichletBC(W.sub(0), 0.0, 2)
bc4 = DirichletBC(W.sub(1), 0.0, 2)
# Advection is constrained on marker 1 only; diffusion on both markers.
bc_A = [bc1,bc2]
bc_D = bc_A + [bc3,bc4]

#========================
#  Initial condition
#========================
# Start from zero for both the transported fields and the species.
for field in (u0, c):
  field.assign(0)

#========================
#  AD linear operators
#========================
initialTime = time.time()
# The split operators are assembled once, with their Dirichlet conditions.
A_D = assemble(a_D,bcs=bc_D)
A_A = assemble(a_A,bcs=bc_A)
solver_A = LinearSolver(A_A,options_prefix="A_",solver_parameters=OS_parameters)
solver_D = LinearSolver(A_D,options_prefix="D_",solver_parameters=OS_parameters)

#=======================
#  RHS for optimization
#=======================
def _lifting(a_O,bc_O):
  """Return the assembled action of a_O on a field holding only bc_O's values.

  ADassemble subtracts this vector from the load before imposing bc_O. Each
  call uses a fresh zero Function, so Dirichlet values from one set of
  conditions (e.g. bc3/bc4 on marker 2, which only diffusion constrains)
  cannot leak into another operator's lifting.
  """
  g = Function(W)
  for bc in bc_O:
    bc.apply(g)
  return assemble(action(a_O,g))
rhs_D = _lifting(a_D,bc_D)
rhs_A = _lifting(a_A,bc_A)

#=======================
#  AD solver functions
#=======================
def ADassemble(A_O,a_O,L_O,bc_O,rhs_O):
  """Refresh one split operator and return its boundary-corrected load vector.

  Reassembles the bilinear form a_O (with Dirichlet conditions bc_O) into the
  existing matrix A_O. Then it assembles the linear form L_O, subtracts the
  precomputed vector rhs_O (rhs_A / rhs_D), and applies each condition in
  bc_O to the result, which is returned.
  """
  assemble(a_O,bcs=bc_O,tensor=A_O)
  load = assemble(L_O)
  load -= rhs_O
  for dirichlet in bc_O:
    dirichlet.apply(load)
  return load

def ADsolve(solver_O,u0_O,b_O=None):
  """Solve one operator-split substep, writing the result into u0_O.

  solver_O -- linear solver for the matching operator (solver_A or solver_D).
  u0_O     -- solution Function, overwritten by the solve.
  b_O      -- right-hand side, typically the vector returned by ADassemble.
              If omitted, the module-level ``b`` is used, preserving the
              original calling convention ``b = ADassemble(...)`` followed by
              ``ADsolve(solver, u0)``.

  Returns u0_O (the same object that was passed in).
  """
  if b_O is None:
    # Legacy implicit dependency on the caller's module-level ``b``.
    b_O = b
  solver_O.solve(u0_O,b_O)
  return u0_O

#=======================
#  Time-stepping
#=======================
# Prefix for the plain-text psi dumps (script path minus extension).
txtPrefix = __file__.rsplit('.',1)[0]
while (t < T):
  # NOTE(review): t accumulates dt in floating point, so whether the last step
  # lands exactly on T depends on rounding; re-check the step count if T or dt
  # change.
  t += dt
  tlevel += 1
  if rank == 0:
    print('\tTime = %f' % t)

  # Operator-split transport. With DAD == 1 a step is
  # diffusion -> advection -> diffusion, otherwise advection -> diffusion.
  # `b` stays at module level because ADsolve reads it.
  if DAD == 1:
    with PETSc.Log.Stage("Diffusion"):
      b = ADassemble(A_D,a_D,L_D,bc_D,rhs_D)
      u0 = ADsolve(solver_D,u0)
  with PETSc.Log.Stage("Advection"):
    b = ADassemble(A_A,a_A,L_A,bc_A,rhs_A)
    u0 = ADsolve(solver_A,u0)
  with PETSc.Log.Stage("Diffusion"):
    b = ADassemble(A_D,a_D,L_D,bc_D,rhs_D)
    u0 = ADsolve(solver_D,u0)

  # Output and reaction post-processing on requested levels. A 0 in printat
  # selects every level. Testing membership once, instead of looping over
  # printat, avoids writing a level twice when printat holds both 0 and it.
  if 0 in printat or tlevel in printat:
    np.savetxt(txtPrefix+"/out_psiA_"+str(tlevel),u0_A.vector().array().reshape(seed+1,),fmt="%f")
    np.savetxt(txtPrefix+"/out_psiB_"+str(tlevel),u0_B.vector().array().reshape(seed+1,),fmt="%f")
    # Zero entries with magnitude below TOL before computing concentrations.
    with u0_A.dat.vec as psiA_vec, u0_B.dat.vec as psiB_vec:
      psiA_vec.chop(TOL)
      psiB_vec.chop(TOL)
    # Per-dof reaction kernel: reads psi (u0) and k1, writes the species c.
    op2.par_loop(reactions_kernel,Q.dof_dset,u0.dat(op2.READ),c.dat(op2.WRITE),
        k1.dat(op2.READ))
    outFile_psiA << u0_A
    outFile_psiB << u0_B
    outFile_A << c_A
    outFile_B << c_B
    outFile_C << c_C

#=======================
#  End simulation
#=======================
# Total wall-clock time since operator assembly began, reported by rank 0 only.
getTime = time.time() - initialTime
if rank == 0:
  print('\tWall-clock time: %.3e seconds' % getTime)
