#===============================================================
#
#  Density driven flow:
#    Original Darcy
#    Homogeneous domain
#    Isotropic medium
#
#  Run as:
#    python 2D_Darcy_Homo_Iso.py <xseed> <yseed> <printat>
#
#  <x/yseed> = number of elements in x/y direction
#  <printat> = Interval to Output vtu file
#
#===============================================================
from firedrake import *
from firedrake.petsc import PETSc

# Explicit imports stay after the star import so it cannot shadow them
import os
import sys
import time

import numpy as np
# MPI rank of this process; console output further down is limited to rank 0
rank = PETSc.COMM_WORLD.getRank()

# Required command-line options
#   <xseed> <yseed> : number of elements in the x / y direction
#   <printat>       : write output every <printat> time steps
if len(sys.argv) < 4:
  # Fail with a usage message instead of a bare IndexError
  sys.exit("Usage: python %s <xseed> <yseed> <printat>" % sys.argv[0])
xseed = int(sys.argv[1])
yseed = int(sys.argv[2])
printat = int(sys.argv[3])
if xseed <= 0 or yseed <= 0:
  # Both seeds are used as divisors when the boundary tolerances are computed
  sys.exit("<xseed> and <yseed> must be positive integers")

# Time stepping (years)
delta_t = 0.05
T = 10

#===============#
#  Create mesh  #
#===============#
Lx = 600
Ly = 150
mesh = RectangleMesh(xseed, yseed, Lx, Ly)

#=====================#
#  Custom boundaries  #
#=====================#
# The "boundary_ids" label is rebuilt with the following markers:
#   1 = left (x ~ 0), 2 = right (x ~ Lx), 3 = bottom (y ~ 0),
#   4 = top (y ~ Ly), 5 = top hole (top faces with 150 <= x <= 450),
#   6 = left/right/bottom faces within one cell of the two bottom corners
plex = mesh._plex
coords = plex.getCoordinates()
coord_sec = plex.getCoordinateSection()
if plex.getStratumSize("boundary_faces", 1) > 0:
  plex.removeLabel("boundary_ids")
  faces = plex.getStratumIS("boundary_faces", 1).getIndices()

  # Half a cell decides "lies on this side"; a full cell is the reach of the
  # corner patches
  xtol = float(Lx)/(2*float(xseed))
  ytol = float(Ly)/(2*float(yseed))
  xsize = xtol*2
  ysize = ytol*2

  for face in faces:
    # Entries 0/2 are used as x and 1/3 as y of the two face end points
    face_coords = plex.vecGetClosure(coord_sec, coords, face)
    fx0, fy0 = face_coords[0], face_coords[1]
    fx1, fy1 = face_coords[2], face_coords[3]

    # Which side of the rectangle the face lies on
    on_left = abs(fx0) < xtol and abs(fx1) < xtol
    on_right = abs(fx0 - Lx) < xtol and abs(fx1 - Lx) < xtol
    on_bottom = abs(fy0) < ytol and abs(fy1) < ytol
    on_top = abs(fy0 - Ly) < ytol and abs(fy1 - Ly) < ytol

    # Proximity to the bottom corners
    near_y0 = abs(fy0) <= ysize and abs(fy1) <= ysize
    near_x0 = abs(fx0) <= xsize and abs(fx1) <= xsize
    near_xL = abs(fx0 - Lx) <= xsize and abs(fx1 - Lx) <= xsize

    # Corners - ID 6
    if ((on_left or on_right) and near_y0) or \
        (on_bottom and (near_xL or near_x0)):
      plex.setLabelValue("boundary_ids", face, 6)
      continue

    if on_left:
      plex.setLabelValue("boundary_ids", face, 1)
    if on_right:
      plex.setLabelValue("boundary_ids", face, 2)
    if on_bottom:
      plex.setLabelValue("boundary_ids", face, 3)

    if on_top:
      # Top hole - ID 5, remainder of the top - ID 4
      in_hole = abs(fx0) >= 150 and abs(fx1) >= 150 and \
          abs(fx0) <= 450 and abs(fx1) <= 450
      if in_hole:
        plex.setLabelValue("boundary_ids", face, 5)
      else:
        plex.setLabelValue("boundary_ids", face, 4)




# Velocity: "RT" (Raviart-Thomas) element of degree 1
V = FunctionSpace(mesh, "RT", 1)
# Pressure: discontinuous, piecewise constant
P = FunctionSpace(mesh, 'DG', 0)
# Concentration: discontinuous, piecewise linear
Q = FunctionSpace(mesh, 'DG', 1)

# Flow: mixed velocity/pressure space with its trial/test functions.
# u holds the mixed solution; vel and pres are its two components.
W = V*P
v,p = TrialFunctions(W)
w,q = TestFunctions(W)
u = Function(W)
vel,pres = u.split()

# Transport: trial/test functions and the concentration at the new time level
c = TrialFunction(Q)
d = TestFunction(Q)
conc = Function(Q, name="solution")

# Initial concentration (also used as the previous-time-level value):
# sin^2(0.5*x) inside the strip 150 <= x <= 450, y >= 145, zero elsewhere
c0 = Function(Q)
c0.project(Expression(("x[0] >= 150 && x[0] <= 450 && x[1] >= 145 ? sin(0.5*x[0])*sin(0.5*x[0]) : 0.0")))
#c0.assign(0)

#================#
#  Permeability  #
#================#
# Single scalar value for the whole domain
#base_perm = 1e-12 #m^2
base_perm = 4.845e-13 #m^2
k = Constant(base_perm)

#============#
#  Porosity  #
#============#
porosity = 0.1
phi = Constant(porosity)

#=======================#
#  Specific body force  #
#=======================#
# NOTE(review): the densities were annotated kg/m^2; kg/m^3 is presumably
# what was meant -- confirm
density1 = 1000 # kg/m^3
density2 = 200 # kg/m^3
rho1 = Constant(density1)
rho2 = Constant(density2)
# Density varies linearly with the previous-step concentration c0:
# rho1 where c0 = 0, rho1 + rho2 where c0 = 1
rho = rho1 + rho2*c0
# Body-force vector (0, -9.81) projected into the velocity space
g = Function(V).project(Expression(("0.0","-9.81")))

#=============#
#  Viscosity  #
#=============#
# NOTE(review): despite its name this value carries m^2/s, i.e. the units of
# a kinematic viscosity; it is multiplied by rho below to obtain mu
dynamic_viscosity = 1e-6 # m^2/s
nu = Constant(dynamic_viscosity)
mu = nu*rho
#mu = Constant(1e-3)

#=====================#
#  Volumetric source  #
#=====================#
# Zero source; f appears in the right-hand sides of both problems
f = Constant(0.0)

#===========================#
#  Diffusivity coefficient  #
#===========================#
D = Constant(3.565e-6) # m^2/s

#=======================#
#  Boundary conditions  #
#=======================#
patm = Constant(101325) # Pascals
# Flow: zero value imposed on the velocity sub-space on every boundary ID
# NOTE(review): the value is built from the string "0.0" -- confirm that
# Constant accepts it in the Firedrake version in use
bcflow1 = DirichletBC(W.sub(0), Constant("0.0"), (1,2,3,4,5,6))
bcflow = [bcflow1]
# Transport: concentration 0 on the bottom (ID 3) and 1 on the top hole (ID 5)
bctran1 = DirichletBC(Q, Constant(0.0), 3, method="geometric")
bctran2 = DirichletBC(Q, Constant(1.0), 5, method="geometric")
bctran = [bctran1,bctran2]

#=====================#
#  Weak formulations  #
#=====================#
n = FacetNormal(mesh)
# Element size in x, used to scale the jump penalty
h = Constant(Lx/float(xseed))
# Jump-penalty coefficient
gamma = Constant(4)
# Time step converted from years to seconds (365-day year)
tstep = delta_t*86400*365
dt = Constant(tstep)

# Flow problem: mixed Darcy form with body force rho*g
# NOTE(review): boundary ID 6 is also part of the velocity Dirichlet set
# above, so the patm term on ds(6) may have no effect -- confirm intent
a_flow = (dot(w,mu/k*v) - div(w)*p - q*div(v))*dx
L_flow = q*f*dx + dot(w,rho*g)*dx - dot(w,n)*patm*ds(6)

# Transport problem
# vn = max(vel.n, 0): the outflow part of the normal velocity
vn = 0.5*(dot(vel,n) + abs(dot(vel,n)))
# Terms of a_tran, in order: time derivative; diffusion on cells; two
# flux/jump coupling terms and a gamma/h jump penalty on interior facets;
# advection on cells; advective flux on interior facets weighted by vn on
# each side; outflow flux on boundary IDs 6 and 3
a_tran = phi*d*c/dt*dx \
  + inner(grad(d), phi*D*grad(c)) * dx \
  - dot(jump(d,n),avg(phi*D*grad(c)))*dS \
  - dot(avg(phi*D*grad(d)),jump(c,n))*dS \
  + gamma/h*dot(jump(d,n),jump(c,n))*dS \
  - dot(grad(d),vel*c)*dx \
  + dot(jump(d),vn('+')*c('+')-vn('-')*c('-'))*dS \
  + dot(d, vn*c)*(ds(6)+ds(3))
# Right-hand side: source plus the previous time level c0
L_tran = d * f * dx + phi*d*c0/dt*dx

#=================#
#  Solver params  #
#=================#
# Options for the flow (mixed) solve: GMRES with a Schur-complement
# fieldsplit preconditioner. Field 0 / field 1 presumably follow the order
# of W = V*P (velocity, then pressure) -- confirm.
flow_parameters = {
  'ksp_type': 'gmres',
  'pc_type': 'fieldsplit',
  'pc_fieldsplit_type': 'schur',
  'pc_fieldsplit_schur_fact_type': 'full',
  'pc_fieldsplit_schur_precondition': 'selfp',
  # Field 0: one ILU application, no inner Krylov iterations
  'fieldsplit_0_ksp_type': 'preonly',
  'fieldsplit_0_pc_type': 'ilu',
  # Field 1: one 'ml' application (hypre alternative kept commented out)
  'fieldsplit_1_ksp_type': 'preonly',
#  'fieldsplit_1_pc_type': 'hypre',
#  'fieldsplit_1_pc_hypre_type': 'boomeramg',
#  'fieldsplit_1_pc_hypre_boomeramg_strong_threshold': 0.75,
#  'fieldsplit_1_pc_hypre_boomeramg_agg_nl': 2,
  'fieldsplit_1_pc_type': 'ml',
  # Outer tolerances; the convergence reason is reported for every solve
  'ksp_rtol': 1e-6,
  'ksp_atol': 1e-50,
  'ksp_converged_reason': True
}
# Options for the transport solve: GMRES preconditioned with hypre BoomerAMG
tran_parameters = {
  'ksp_type': 'gmres',
  'pc_type': 'hypre',
  'pc_hypre_type': 'boomeramg',
  'pc_hypre_boomeramg_strong_threshold': 0.75,
  'pc_hypre_boomeramg_agg_nl': 2,
#  'pc_type': 'gamg',
  'ksp_rtol': 1e-5,
  'ksp_atol': 1e-50,
  'ksp_converged_reason': True
}

#=================#
#  Setup solvers  #
#=================#
# Flow: the operator depends on c0 through mu, so it is not treated as
# constant between solves
problem_flow = LinearVariationalProblem(
  a_flow, L_flow, u,
  bcs=bcflow,
  constant_jacobian=False)
solver_flow = LinearVariationalSolver(
  problem_flow,
  solver_parameters=flow_parameters,
  options_prefix="flow_")

# Transport: the operator depends on the current velocity, so it is not
# treated as constant either
problem_tran = LinearVariationalProblem(
  a_tran, L_tran, conc,
  bcs=bctran,
  constant_jacobian=False)
solver_tran = LinearVariationalSolver(
  problem_tran,
  solver_parameters=tran_parameters,
  options_prefix="tran_")

#==========#
#  Output  #
#==========#
# Output is only set up when a positive print interval was requested
if printat > 0:
  # Use just the script's base name (no directory, no extension) so that
  # launching the script through a path still writes inside Figures/
  script_name = os.path.splitext(os.path.basename(__file__))[0]
  file_prefix = 'Figures/'+script_name+'_'+str(xseed)+'_x_'+str(yseed)
  outfile = File(file_prefix+'.pvd')

#=================#
#  Solve problem  #
#=================#
# Number of steps needed to reach T. Counting steps with an integer avoids
# the round-off that accumulates in "t += delta_t", which can trigger one
# extra step. The small tolerance keeps an exact multiple from being rounded
# up; a non-integer ratio still rounds up, as the original "while t < T" did.
num_steps = int(np.ceil(T/float(delta_t) - 1e-9))
printcount = 0
t = 0.0
for step in range(1, num_steps + 1):
  t = step*delta_t
  if rank == 0:
    # Parenthesised single-argument print works on Python 2 and 3 alike
    print('\tTime level: %.3e years' % t)

  # Flow solver (zero initial guess; density/viscosity use the previous c0)
  u.assign(0)
  initialTime = time.time()
  solver_flow.solve()

  # Transport solver (zero initial guess; advected by the velocity just solved)
  conc.assign(0)
  solver_tran.solve()

  if rank == 0:
    # Time spent in the two solves of this step
    finalTime = time.time() - initialTime
    print('Wall-clock time: %.3e seconds' % finalTime)

  # Output every <printat> steps; never triggers when printat <= 0
  printcount += 1
  if printat == printcount:
    printcount = 0
    outfile.write(conc,vel,pres)

  # Update: the new concentration becomes the previous time level
  c0.assign(conc)

#=============#
#  Visualize  #
#=============#
# Best-effort plot of the final concentration; failures only produce warnings
try:
  import matplotlib.pyplot as plt
except Exception:
  # Still best-effort, but no longer swallows KeyboardInterrupt/SystemExit
  plt = None
  warning("Matplotlib not imported")
try:
  plot(conc)
except Exception as e:
  # str(e) instead of e.message, which does not exist on Python 3
  warning("Cannot plot figure. Error msg '%s'" % str(e))
if plt is not None:
  # Only try to show a figure when matplotlib is actually available
  try:
    plt.show()
  except Exception as e:
    warning("Cannot show figure. Error msg '%s'" % str(e))
