#================================================
#
#  Advective-diffusive-reactive 1D psi form1
#  By: Justin Chang
#
#
#  Run as: 
#    python 1D_OS_analytical.py <seed> <OS> <VI>
#
#  <seed> = number of elements
#  <OS> = Operator split type
#    0 = AD
#    1 = DAD
#    2 = ADAD
#  <VI> = Use Newton (0) or VI (1) solver
#    (in the code visible in this file the value only tags the output file names)
#
#================================================
from __future__ import division
# Time levels at which concentrations are computed, errors are sampled and
# files are written; a 0 in the set means "every time level".
printat = {0}
#printat = {1,10,50,100}
problem = 1                # Problem ID
vel = 1                    # Velocity
diffusivity = 1/150        # Diffusivity (true division: see the __future__ import above)
dt = 0.01                  # Global time-step
T = 2.00                   # Maximum time
keq = 1.0                  # Equilibrium constant
TOL = 1e-11                # Tolerance
import sys

# Validate the command line before any mesh or solver is built, so that a
# wrong invocation ends with a readable message instead of a bare
# IndexError/ValueError, and an unsupported split type is not silently run as
# a different scheme.  Extra trailing arguments (e.g. PETSc options) are
# allowed, hence "< 4" rather than "!= 4".
_USAGE = ('Usage: python %s <seed> <OS> <VI>\n'
          '  <seed> = number of elements (positive integer)\n'
          '  <OS>   = operator split type: 0 = AD, 1 = DAD, 2 = ADAD\n'
          '  <VI>   = 0 (Newton) or 1 (VI)')
if len(sys.argv) < 4:
  sys.exit(_USAGE % sys.argv[0])
try:
  seed = int(sys.argv[1])
  opt_OS = int(sys.argv[2])
  opt_VI = int(sys.argv[3])
except ValueError:
  sys.exit('Error: <seed>, <OS> and <VI> must be integers.\n' + _USAGE % sys.argv[0])
if seed < 1:
  sys.exit('Error: <seed> must be a positive integer.\n' + _USAGE % sys.argv[0])
if opt_OS not in (0, 1, 2):
  sys.exit('Error: <OS> must be 0, 1 or 2.\n' + _USAGE % sys.argv[0])
if opt_VI not in (0, 1):
  sys.exit('Error: <VI> must be 0 or 1.\n' + _USAGE % sys.argv[0])

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import time
# MPI rank of this process; used below to restrict console output to rank 0.
rank = PETSc.COMM_WORLD.getRank()
# Set Firedrake's global "matnest" parameter to False.
# NOTE(review): presumably this requests non-nested matrices for the mixed
# space so the solver options below apply to one matrix -- confirm against
# the Firedrake version in use.
parameters["matnest"] = False
# Initialise PyOP2 with its log level set to WARNING.
op2.init(log_level='WARNING')

#========================
#  Create mesh 
#========================
# 1D unit-interval mesh; the element count comes from the <seed> argument.
mesh = UnitIntervalMesh(seed)
# Continuous piecewise-linear (CG1) spaces: V is vector-valued, Q is scalar.
# NOTE(review): V is not referenced again in this file; it may be used by the
# scripts pulled in with execfile -- confirm before removing.
V = VectorFunctionSpace(mesh,"CG",1)
Q = FunctionSpace(mesh, "CG", 1)
# Mixed space for the two transported unknowns (named psiA / psiB below).
W = Q*Q
u_A,u_B = TrialFunctions(W)
v_A,v_B = TestFunctions(W)
# u0 stores the current solution; it is both the "old" state in the
# right-hand sides and the target of every sub-step solve.
u0 = Function(W)
u0_A,u0_B = u0.split()

#========================
#  Time-stepping
#========================
# Fraction of the global step dt used by each kind of sub-step:
#   OS = 0 (AD)  : advection 1.0, diffusion 1.0
#   OS = 1 (DAD) : advection 1.0, diffusion 0.5
#   OS = 2 (ADAD): advection 0.5, diffusion 0.5
tscale_D = 0.5 if opt_OS >= 1 else 1.0
tscale_A = 0.5 if opt_OS == 2 else 1.0
dt_D = Constant(dt*tscale_D)
dt_A = Constant(dt*tscale_A)
# Simulation clock and time-level counter.
tlevel = 0
t = 0

#========================
#  Supplementary files
#========================
# Run the two helper scripts inside this module's namespace (Python 2
# execfile).  Names used later but never defined in this file
# (psiA1_expression, psiB1_expression, fpsiA1_expression, fpsiB1_expression,
# cA_expression, cB_expression, cC_expression, reactions3_kernel) presumably
# come from here -- confirm against those files.
execfile('1D_expressions.py')
execfile('reactions3_kernel.py')

#========================
#  Advection velocity
#========================
# Constant one-component velocity vector built from `vel`.
velocity = as_vector((Constant(vel),))

#========================
#  Dispersion tensor
#========================
# Scalar constant diffusion coefficient.
D = Constant(diffusivity)

#========================
#  Volumetric source
#========================
# Source terms for the two unknowns, stored as functions in Q.
f_A = Function(Q)
f_B = Function(Q)
# Weights applied to the source in the advection and diffusion right-hand
# sides respectively (see the weak forms below).
f_A_factor = Constant(0.5)
f_D_factor = Constant(0.5)
# Initial interpolation; the time loop re-interpolates after updating `.t`.
f_A.interpolate(fpsiA1_expression)
f_B.interpolate(fpsiB1_expression)

#========================
#  Solver parameters
#========================
# PETSc options shared by the advection and diffusion linear solvers:
# conjugate-gradient Krylov method with a GAMG preconditioner.
OS_parameters = dict(ksp_type='cg', pc_type='gamg')

#========================
#  Output files
#========================
# All .pvd files go under Figures/<this script's path without its extension>/.
_fig_dir = 'Figures/' + __file__.rsplit('.',1)[0]
# Computed fields are tagged with the split type and the solver flag so that
# different runs do not overwrite each other.
_run_tag = str(opt_OS) + str(opt_VI)

# Computed concentrations
outFile_A = File('%s/cA%s.pvd' % (_fig_dir, _run_tag))
outFile_B = File('%s/cB%s.pvd' % (_fig_dir, _run_tag))
outFile_C = File('%s/cC%s.pvd' % (_fig_dir, _run_tag))
# Computed transported unknowns
outFile_psiA = File('%s/psiA%s.pvd' % (_fig_dir, _run_tag))
outFile_psiB = File('%s/psiB%s.pvd' % (_fig_dir, _run_tag))
# Exact fields (untagged file names)
outFile_exact_psiA = File(_fig_dir + '/exact_psiA.pvd')
outFile_exact_psiB = File(_fig_dir + '/exact_psiB.pvd')
outFile_exact_cA = File(_fig_dir + '/exact_cA.pvd')
outFile_exact_cB = File(_fig_dir + '/exact_cB.pvd')
outFile_exact_cC = File(_fig_dir + '/exact_cC.pvd')

#========================
#  Weak form
#========================
# Advection weak form.
# Implicit step of size dt_A for  u + dt_A*(a.grad u) = u0 + dt_A*f_A_factor*f,
# tested with (v + dt_A*(a.grad v)), i.e. the same operator is applied to the
# test function as to the trial function, which makes the bilinear form
# symmetric.  The two unknowns are uncoupled here; the forms are simply
# summed over the mixed space.
a_A_psiA = (v_A + dt_A*dot(velocity,grad(v_A)))*(u_A + dt_A*dot(velocity,grad(u_A)))*dx
a_A_psiB = (v_B + dt_A*dot(velocity,grad(v_B)))*(u_B + dt_A*dot(velocity,grad(u_B)))*dx
L_A_psiA = (v_A + dt_A*dot(velocity,grad(v_A)))*(u0_A + dt_A*f_A_factor*f_A)*dx
L_A_psiB = (v_B + dt_A*dot(velocity,grad(v_B)))*(u0_B + dt_A*f_A_factor*f_B)*dx
a_A = a_A_psiA + a_A_psiB
L_A = L_A_psiA + L_A_psiB
# Diffusion weak form.
# Implicit step of size dt_D: mass term plus dt_D*D*(grad v . grad u) on the
# left; previous state u0 plus the weighted source dt_D*f_D_factor*f on the
# right.  Again the two unknowns are uncoupled.
a_D_psiA = v_A*u_A*dx + dt_D*dot(grad(v_A),D*grad(u_A))*dx
a_D_psiB = v_B*u_B*dx + dt_D*dot(grad(v_B),D*grad(u_B))*dx
L_D_psiA = v_A*u0_A*dx + v_A*dt_D*f_D_factor*f_A*dx
L_D_psiB = v_B*u0_B*dx + v_B*dt_D*f_D_factor*f_B*dx
a_D = a_D_psiA + a_D_psiB
L_D = L_D_psiA + L_D_psiB

#========================
#  Reactions
#========================
# Three-component mixed space holding the concentrations c_A, c_B, c_C that
# the reaction kernel writes from psiA / psiB in the time loop.
W_R = Q*Q*Q
c = Function(W_R)
c_A,c_B,c_C = c.split()
# Equilibrium constant, passed to the reaction kernel as a read-only argument.
k1 = Constant(keq)

#========================
#  Exact solutions
#========================
# Reference fields in Q; re-interpolated at every time level in the loop
# after the expressions' `.t` attribute has been updated.
exact_psiA = Function(Q)
exact_psiB = Function(Q)
exact_cA = Function(Q)
exact_cB = Function(Q)
exact_cC = Function(Q)
exact_psiA.interpolate(psiA1_expression)
exact_psiB.interpolate(psiB1_expression)
exact_cA.interpolate(cA_expression)
exact_cB.interpolate(cB_expression)
exact_cC.interpolate(cC_expression)
# Running sums of the per-level error norms; averaged when the run finishes.
L2_psiA = 0.0
L2_psiB = 0.0
L2_cA = 0.0
L2_cB = 0.0
L2_cC = 0.0

#========================
#  Dirichlet BCs
#========================
# Boundary id 1: psiA = 0, psiB = 1.  Boundary id 2: psiA = psiB = 0.
# NOTE(review): for UnitIntervalMesh ids 1 and 2 are presumably x = 0 and
# x = 1 -- confirm against the Firedrake mesh documentation.
bc1 = DirichletBC(W.sub(0), 0.0, (1))
bc2 = DirichletBC(W.sub(1), 1.0, (1))
bc3 = DirichletBC(W.sub(0), 0.0, (2))
bc4 = DirichletBC(W.sub(1), 0.0, (2))
# Advection sub-steps constrain boundary 1 only; diffusion sub-steps
# constrain both boundaries.
bc_A = [bc1,bc2]
bc_D = [bc1,bc2,bc3,bc4]

#========================
#  Initial conditions
#========================
# Start from the expressions evaluated before any time update is applied.
u0_A.interpolate(psiA1_expression)
u0_B.interpolate(psiB1_expression)
c_A.interpolate(cA_expression)
c_B.interpolate(cB_expression)
c_C.interpolate(cC_expression)

#========================
#  AD linear operators
#========================
# Wall-clock timer starts here, so the reported time covers operator
# assembly, solver setup and the whole time loop.
initialTime = time.time()
# The left-hand-side forms do not depend on the solution or on time, so the
# matrices are assembled once and reused by every sub-step.
A_D = assemble(a_D,bcs=bc_D)
A_A = assemble(a_A,bcs=bc_A)
# One solver per operator; the "A_" / "D_" prefixes allow per-solver PETSc
# options on the command line.
solver_A = LinearSolver(A_A,options_prefix="A_",solver_parameters=OS_parameters)
solver_D = LinearSolver(A_D,options_prefix="D_",solver_parameters=OS_parameters)

#=======================
#  AD solver functions
#=======================
# NOTE(review): tmp is not referenced anywhere else in this file -- confirm
# whether it is still needed.
tmp = Function(W)
def ADassemble(L_O,bc_O):
  """Assemble the right-hand side of one sub-step.

  L_O  -- linear form to assemble.
  bc_O -- boundary conditions passed to assemble() as `bcs`.

  Returns whatever assemble() returns for the form.
  """
  return assemble(L_O, bcs=bc_O)

def ADsolve(solver_O,u0_O,b):
  """Solve one sub-step and hand back the updated solution.

  solver_O -- object whose solve(x, rhs) method performs the linear solve.
  u0_O     -- solution container; passed to solve() as the first argument.
  b        -- assembled right-hand side.

  Returns u0_O itself (the same object that was passed in).
  """
  solver_O.solve(u0_O, b)
  return u0_O

#=======================
#  Time-stepping
#=======================
while (t < T):
  t += dt
  tlevel += 1

  # Update for time
  fpsiA1_expression.t = t
  fpsiB1_expression.t = t
  f_A.interpolate(fpsiA1_expression)
  f_B.interpolate(fpsiB1_expression)
  psiA1_expression.t = t
  psiB1_expression.t = t
  cA_expression.t = t
  cB_expression.t = t
  cC_expression.t = t
  exact_psiA.interpolate(psiA1_expression)
  exact_psiB.interpolate(psiB1_expression)
  exact_cA.interpolate(cA_expression)
  exact_cB.interpolate(cB_expression)
  exact_cC.interpolate(cC_expression)
  if rank == 0:
    print '\tTime = %f' % t
  
  # Advection-diffusion
  with PETSc.Log.Stage("Advection"):
    if opt_OS == 2:    
      b = ADassemble(L_A,bc_A)
      u0 = ADsolve(solver_A,u0,b)
  with PETSc.Log.Stage("Diffusion"):
    if opt_OS >= 1:
      b = ADassemble(L_D,bc_D)
      u0 = ADsolve(solver_D,u0,b)
  with PETSc.Log.Stage("Advection"):
    b = ADassemble(L_A,bc_A)
    u0 = ADsolve(solver_A,u0,b)
  with PETSc.Log.Stage("Diffusion"):
    b = ADassemble(L_D,bc_D)
    u0 = ADsolve(solver_D,u0,b)
  
  # Reactions
  for i in printat:
    if i == tlevel or i == 0:      
      with u0_A.dat.vec as psiA_vec, u0_B.dat.vec as psiB_vec:
        psiA_vec.chop(TOL)
        psiB_vec.chop(TOL)
      op2.par_loop(reactions3_kernel,Q.dof_dset,u0_A.dat(op2.READ),u0_B.dat(op2.READ),c_A.dat(op2.WRITE),c_B.dat(op2.WRITE),c_C.dat(op2.WRITE),k1.dat(op2.READ))
      with c_A.dat.vec as cA_vec, c_B.dat.vec as cB_vec, c_C.dat.vec as cC_vec:
        cA_vec.chop(TOL)
        cB_vec.chop(TOL)
        cC_vec.chop(TOL)
      
      # Calculate L2 error norm
      L2_psiA += norm(assemble(u0_A-exact_psiA))
      L2_psiB += norm(assemble(u0_B-exact_psiB))
      L2_cA += norm(assemble(c_A-exact_cA))
      L2_cB += norm(assemble(c_B-exact_cB))
      L2_cC += norm(assemble(c_C-exact_cC))

      # Output images
      outFile_exact_psiA << exact_psiA
      outFile_exact_psiB << exact_psiB
      outFile_exact_cA << exact_cA
      outFile_exact_cB << exact_cB
      outFile_exact_cC << exact_cC
      outFile_psiA << u0_A
      outFile_psiB << u0_B
      outFile_A << c_A
      outFile_B << c_B
      outFile_C << c_C
      
      # Output text files
      #np.savetxt(__file__.rsplit('.',1)[0]+"/out_psiA_"+str(tlevel),u0_A.vector().array().reshape(seed+1,),fmt="%f")
      #np.savetxt(__file__.rsplit('.',1)[0]+"/out_psiB_"+str(tlevel),u0_B.vector().array().reshape(seed+1,),fmt="%f")

#=======================
#  End simulation
#=======================
getTime = time.time()-initialTime
if rank == 0:
  print '\tWall-clock time: %.3e seconds' % getTime
  print '\tL2 error norm - psiA: %.3e' % float(L2_psiA/tlevel)
  print '\tL2 error norm - psiB: %.3e' % float(L2_psiB/tlevel)
  print '\tL2 error norm - cA: %.3e' % float(L2_cA/tlevel)
  print '\tL2 error norm - cB: %.3e' % float(L2_cB/tlevel)
  print '\tL2 error norm - cC: %.3e' % float(L2_cC/tlevel)
