#================================================
#
#  DG Test
#  By: Justin Chang
#
#  2A + B <=> 2C
#
#  Run as:
#    python DG_test.py <xseed> <yseed>
#
#  <x/yseed> = number of elements in each direction
#
#================================================
# Time levels at which the solution is written to disk.  A 0 entry makes
# the time loop write after every step (e.g. printat = {0}).
printat = {1, 10, 100}

# Penalty coefficients (wrapped in Constants for the DG weak form)
alpha_penalty = 40
gamma_penalty = 8

# Time discretization
dt = 0.01                  # Global time-step
T = 1.00                   # Maximum time

# Dispersion model coefficients
aD = 1e-6                  # Molecular diffusion
at = 1e-4                  # Transverse dispersivity
aL = 1                     # Longitudinal dispersivity
tort = 1                   # Tortuosity

TOL = 1e-7                 # Tolerance

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import time
import sys
import math
# MPI rank of this process; console output later in the script is guarded
# by `rank == 0`.
rank = PETSc.COMM_WORLD.getRank()
parameters["matnest"] = False
op2.init(log_level='WARNING')

# Mesh resolution comes from the first two command-line arguments.  Extra
# arguments are tolerated (they were ignored before as well).  Fail early
# with a usage message instead of an IndexError/ValueError traceback, and
# reject non-positive seeds, which would otherwise only surface later
# (e.g. as a division by zero when the element size 1/yseed is computed).
if len(sys.argv) < 3:
  raise SystemExit("Usage: python %s <xseed> <yseed>" % sys.argv[0])
try:
  xseed = int(sys.argv[1])
  yseed = int(sys.argv[2])
except ValueError:
  raise SystemExit("<xseed> and <yseed> must be integers")
if xseed <= 0 or yseed <= 0:
  raise SystemExit("<xseed> and <yseed> must be positive")

#========================
#  Create mesh 
#========================
# Quadrilateral mesh with xseed-by-yseed cells (presumably on a 2-by-1
# rectangle -- confirm against the RectangleMesh signature in use).
mesh = RectangleMesh(xseed, yseed, 2, 1, quadrilateral=True)

# First-order DG spaces: vector-valued (V), scalar (Q), and the two-field
# mixed space (W) that carries both species A and B.
V = VectorFunctionSpace(mesh, "DG", 1)
Q = FunctionSpace(mesh, "DG", 1)
W = Q * Q

# Trial and test functions, one pair per species
(u_A, u_B) = TrialFunctions(W)
(v_A, v_B) = TestFunctions(W)

# Mixed solution function and its per-species components
u0 = Function(W)
(u0_A, u0_B) = u0.split()

# Mixed function holding inlet profiles, split per species
bcIn = Function(W)
(bcInA, bcInB) = bcIn.split()

#========================
#  Advection velocity
#========================
# Components of the prescribed velocity field, written in the Expression
# string syntax.  Each is a sum of three terms; adjacent literals are
# concatenated so the resulting strings are unchanged.
_velocity_x_code = (
    "1 + 0.08*pi*cos(4*pi*x[0]*0.5 - 0.5*pi)*cos(pi*x[1])"
    " + 0.1*pi*cos(5*pi*x[0]*0.5 - 0.5*pi)*cos(5*pi*x[1])"
    " + 0.1*pi*cos(10*pi*x[0]*0.5 - 0.5*pi)*cos(10*pi*x[1])"
)
_velocity_y_code = (
    "0.16*pi*sin(2*pi*x[0] - 0.5*pi)*sin(pi*x[1])"
    " + 0.05*pi*sin(5*pi*x[0]*0.5 - 0.5*pi)*sin(5*pi*x[1])"
    " + 0.05*pi*sin(5*pi*x[0] - 0.5*pi)*sin(10*pi*x[1])"
)

# Interpolate the analytic field into the vector DG space V
velocity = interpolate(Expression((_velocity_x_code, _velocity_y_code)), V)

#========================
#  Dispersion tensor
#========================
# Velocity magnitude and identity tensor of the mesh's geometric dimension
normv = sqrt(dot(velocity, velocity))
Id = Identity(mesh.ufl_cell().geometric_dimension())

# Model coefficients as Constants; molecular diffusion is scaled by the
# tortuosity.
alpha_t = Constant(at)
alpha_L = Constant(aL)
alpha_D = Constant(aD*tort)

# D = (alpha_D + alpha_t*|v|) I + (alpha_L - alpha_t) (v (x) v) / |v|
_isotropic_part = (alpha_D + alpha_t*normv)*Id
_directional_part = (alpha_L - alpha_t)*outer(velocity, velocity)/normv
D = _isotropic_part + _directional_part

#========================
#  Volumetric source
#========================
# Source terms for species A and B, both set to zero everywhere
f_A = Function(Q)
f_B = Function(Q)
for _source in (f_A, f_B):
  _source.interpolate(Expression("0.0"))

# Source scaling factors; f_D_factor multiplies the source in the
# diffusion right-hand side.
f_A_factor = Constant(0.5)
f_D_factor = Constant(0.5)

#========================
#  Time-stepping
#========================
# Diffusion step size: the global step dt scaled by tscale_D
tscale_D = 1.0
dt_D = Constant(dt * tscale_D)

# Current simulation time and time-level counter
t, tlevel = 0, 0

#========================
#  Solver parameters
#========================
# Options handed to solve() as solver_parameters: GMRES with a hypre
# preconditioner on an AIJ matrix.  (A specific hypre variant could be
# chosen with 'pc_hypre_type', e.g. 'boomeramg'; it is not set here.)
OS_parameters = dict(
  ksp_type='gmres',
  pc_type='hypre',
  mat_type='aij',
)

#========================
#  Output files
#========================
# All output goes under Figures/<script path without its extension>/
_output_dir = 'Figures/' + __file__.rsplit('.',1)[0]

# The velocity field is written once, up front
outFile_velocity = File(_output_dir + "/vel.pvd")
outFile_velocity.write(velocity)

# The solution file is written to from inside the time loop
outFile_psi = File(_output_dir + '/sol.pvd')

#========================
#  Weak form
#========================
# Facet normal used by the jump terms of the DG forms
n = FacetNormal(mesh)

# Element size taken as 1/yseed.  The division must be done in floating
# point: with the Python 2 print statements used in this script, `1/yseed`
# is integer division and evaluates to 0 for any yseed > 1, which makes
# the penalty term alpha/h_avg a division by zero.
h = Constant(1.0/yseed)
#h = CellSize(mesh)

# Penalty coefficients as Constants
alpha = Constant(alpha_penalty)
gamma = Constant(gamma_penalty)

# Facet-averaged element size; constant here (the per-cell alternative
# would be 0.5*(h('+')+h('-')) with h = CellSize(mesh)).
h_avg = Constant(1.0/yseed)

# Inlet profiles: species A is 1 where x[1] < 0.5, species B is 1 where
# x[1] > 0.5, and each is 0 elsewhere.
bcInA.interpolate(Expression("x[1] < 0.5 ? 1.0 : 0.0"))
bcInB.interpolate(Expression("x[1] > 0.5 ? 1.0 : 0.0"))

# Diffusion weak form
def _implicit_diffusion_lhs(v, u):
  """Return the bilinear form of one implicit diffusion step for one species.

  The form is a mass term plus dt_D times: the volume diffusion term,
  minus the two interior-facet terms pairing jumps with averaged fluxes,
  plus the interior-facet jump penalty scaled by alpha/h_avg.
  """
  mass = v*u*dx
  volume = dot(grad(v), D*grad(u))*dx
  flux_of_trial = dot(jump(v, n), avg(D*grad(u)))*dS
  flux_of_test = dot(avg(D*grad(v)), jump(u, n))*dS
  penalty = alpha/h_avg*dot(jump(v, n), jump(u, n))*dS
  return mass + dt_D*(volume - flux_of_trial - flux_of_test + penalty)


def _implicit_diffusion_rhs(v, u_old, source):
  """Return the linear form: previous solution plus the scaled source."""
  return v*u_old*dx + dt_D*(v*f_D_factor*source)*dx


# NOTE: exterior-facet (ds) terms are intentionally absent.  Earlier
# versions of these forms carried disabled gamma/h boundary-penalty and
# boundary-flux terms on ds, ds(1) and ds(2); they remain disabled here.
a_D_psiA = _implicit_diffusion_lhs(v_A, u_A)
a_D_psiB = _implicit_diffusion_lhs(v_B, u_B)
L_D_psiA = _implicit_diffusion_rhs(v_A, u0_A, f_A)
L_D_psiB = _implicit_diffusion_rhs(v_B, u0_B, f_B)

# Both species are assembled into one mixed system
a_D = a_D_psiA + a_D_psiB
L_D = L_D_psiA + L_D_psiB

#========================
#  Dirichlet BCs
#========================
# Profiles imposed on boundary 1: species A is 1 for 0 < x[1] <= 0.5 and
# species B is 1 for 0.5 < x[1] < 1, each 0 otherwise.
_profile_A = Expression(("x[1] > 0.0 && x[1] <= 0.5 ? 1.0 : 0"))
_profile_B = Expression(("x[1] > 0.5 && x[1] < 1.0 ? 1.0 : 0.0"))

# Boundaries 2, 3 and 4 get homogeneous values for both species
_zero_boundaries = (2, 3, 4)

bc1 = DirichletBC(W.sub(0), _profile_A, 1, method="geometric")
bc2 = DirichletBC(W.sub(1), _profile_B, 1, method="geometric")
bc3 = DirichletBC(W.sub(0), 0.0, _zero_boundaries, method="geometric")
bc4 = DirichletBC(W.sub(1), 0.0, _zero_boundaries, method="geometric")
bc_D = [bc1, bc2, bc3, bc4]

#========================
#  Initial condition
#========================
# Start from a zero field for both species
u0.assign(0)

#=======================
#  Time-stepping
#=======================
initialTime = time.time()

# Fix the number of steps up front instead of testing an accumulated
# floating-point time against T: repeatedly adding dt can leave t a hair
# below T (e.g. ten additions of 0.1 give 0.9999999999999999) and trigger
# one step too many.  The small offset keeps an exact multiple such as
# T/dt == 100.0 from being rounded up; a T that is not a multiple of dt
# still gets the final partial-overshoot step, as before.
num_steps = int(math.ceil(T/dt - 1e-9))

# A 0 entry in printat requests output after every step
write_every_step = 0 in printat

while tlevel < num_steps:
  tlevel += 1
  t = tlevel*dt
  if rank == 0:
    # Parenthesised single-argument form prints identically on Python 2
    # and is valid Python 3.
    print('\tTime = %0.3f' % t)

  # NOTE(review): bc_D is built earlier in the script but is not passed
  # to solve() here (no bcs= argument) -- confirm whether that is intended.
  with PETSc.Log.Stage("Diffusion"):
    solve(a_D==L_D,u0,solver_parameters=OS_parameters,options_prefix="D_")

  # Write at most once per step, even if printat contains both 0 and the
  # current level.
  if write_every_step or tlevel in printat:
    outFile_psi.write(u0_A,u0_B)

#=======================
#  End simulation
#=======================
# Wall-clock duration of the time loop, reported by rank 0 only
getTime = time.time()-initialTime
if rank == 0:
  print('\tWall-clock time: %.3e seconds' % getTime)
