#===============================================================
#
#  Density driven flow:
#    Original Darcy
#    Homogeneous domain
#    Isotropic medium
#
#  Run as:
#    python 2D_Darcy_Homo_Iso.py <xseed> <yseed> <printat>
#
#  <x/yseed> = number of elements in x/y direction
#  <printat> = output interval, in time steps, between .pvd/.vtu writes
#
#===============================================================
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np, sys, time
rank = PETSc.COMM_WORLD.getRank()

# Required command-line options: <xseed> <yseed> <printat>
# Fail with a usage message instead of a bare IndexError/ValueError.
if len(sys.argv) < 4:
  if rank == 0:
    sys.stderr.write('Usage: python %s <xseed> <yseed> <printat>\n' % sys.argv[0])
  sys.exit(1)
xseed = int(sys.argv[1])
yseed = int(sys.argv[2])
printat = int(sys.argv[3])
# Non-positive element counts cannot build a mesh, and a non-positive
# output interval would silently suppress every output write.
if xseed < 1 or yseed < 1 or printat < 1:
  if rank == 0:
    sys.stderr.write('Error: <xseed>, <yseed> and <printat> must be positive integers\n')
  sys.exit(1)

# Time stepping (all times in seconds): one-day steps over 365 days
tstep = 86400
T = 365*86400
t = 0.0

#===============#
#  Create mesh  #
#===============#
# Rectangular domain of size Lx x Ly (presumably metres, given the SI units
# used elsewhere) split into xseed x yseed cells.
Lx = 600
Ly = 150
mesh = RectangleMesh(xseed,yseed,Lx,Ly)

#=====================#
#  Custom boundaries  #
#=====================#
# Add extra facet labels on top of the mesh's existing boundary ids:
#   5 -> part of the top wall (the "hole" where concentration is prescribed)
#   6 -> short segments near the domain corners (where patm is imposed)
# NOTE(review): whether these faces also keep their original ids (1-4) depends
# on PETSc DMLabel semantics for setLabelValue - confirm for your version.
plex = mesh._plex
coords = plex.getCoordinates()
coord_sec = plex.getCoordinateSection()
# Presumably guards processes that own no boundary faces when run in parallel.
if plex.getStratumSize("boundary_faces",1) > 0:
  faces = plex.getStratumIS("boundary_faces",1).getIndices()
  # Half a cell width/height: tolerance for "this face lies on the wall".
  xtol = float(Lx)/(2*float(xseed))
  ytol = float(Ly)/(2*float(yseed))
  # One full cell width/height: extent of a "corner" segment.
  xsize = xtol*2
  ysize = ytol*2
  for face in faces:
    # Coordinates of the face's closure; the code below reads them as
    # (x0, y0, x1, y1) for the two end vertices of the edge.
    face_coords = plex.vecGetClosure(coord_sec,coords,face)
    
    # Top hole - ID 5
    # Faces on the top wall (y ~ Ly) whose endpoints both lie in 150 <= x <= 450.
    # NOTE(review): the x-range is hard-coded rather than derived from Lx.
    if abs(face_coords[1] - Ly) < ytol and abs(face_coords[3] - Ly) < ytol and \
        abs(face_coords[0]) >= 150 and abs(face_coords[2]) >= 150 and \
        abs(face_coords[0]) <= 450 and abs(face_coords[2]) <= 450:
      plex.setLabelValue("boundary_ids", face, 5)

    # Corners - ID 6
    # The four alternatives, in order, select faces:
    #   1. on the left wall (x ~ 0) within one cell height of the top,
    #   2. on the right wall (x ~ Lx) within one cell height of the top,
    #   3. on the bottom wall (y ~ 0) within one cell width of x = Lx,
    #   4. on the top wall (y ~ Ly) within one cell width of x = Lx.
    # NOTE(review): there is no clause for the top wall near x = 0, and
    # clause 3 picks the bottom-right corner instead. If symmetric top corners
    # were intended, clause 3 probably should test y ~ Ly and x <= xsize.
    # Confirm.
    if (abs(face_coords[0]) < xtol and abs(face_coords[2]) < xtol and \
        abs(face_coords[1] - Ly) <= ysize and abs(face_coords[3] - Ly) <= ysize) or \
        (abs(face_coords[0] - Lx) < xtol and abs(face_coords[2] - Lx) < xtol and \
        abs(face_coords[1] - Ly) <= ysize and abs(face_coords[3] - Ly) <= ysize) or \
        (abs(face_coords[1]) < ytol and abs(face_coords[3]) < ytol and \
        abs(face_coords[0] - Lx) <= xsize and abs(face_coords[2] - Lx) <= xsize) or \
        (abs(face_coords[1] - Ly) < ytol and abs(face_coords[3] - Ly) < ytol and \
        abs(face_coords[0] - Lx) <= xsize and abs(face_coords[2] - Lx) <= xsize):
      plex.setLabelValue("boundary_ids", face, 6)

#===================#
#  Function spaces  #
#===================#
# Velocity: Raviart-Thomas, degree 1 (H(div) space)
V = FunctionSpace(mesh, "RT", 1)
# Pressure: discontinuous piecewise constants
P = FunctionSpace(mesh, 'DG', 0)
# Concentration: discontinuous piecewise linears
Q = FunctionSpace(mesh, 'DG', 1)

# Mixed velocity/pressure space for the Darcy flow problem
W = V*P
v, p = TrialFunctions(W)
w, q = TestFunctions(W)
u = Function(W)
# Sub-functions of u used for output. The time loop writes them without
# re-splitting, so it relies on them tracking u after each solve.
vel, pres = u.split()

# Transport unknowns: trial/test pair and the solved concentration field
c = TrialFunction(Q)
d = TestFunction(Q)
conc = Function(Q, name="solution")

# Concentration from the previous time step; the run starts from zero.
c0 = Function(Q)
c0.assign(0)

#================#
#  Permeability  #
#================#
# Homogeneous, isotropic medium: one scalar permeability for the whole domain
base_perm = 1e-12 # m^2
k = Constant(base_perm)

#=======================#
#  Specific body force  #
#=======================#
# Density grows linearly with concentration: rho = density1 + density2*c0,
# i.e. 1000 kg/m^3 at c0 = 0 and 1200 kg/m^3 at c0 = 1. rho is a UFL
# expression of c0, so it follows c0 whenever c0 is re-assigned.
density1 = 1000 # kg/m^3 (base density)
density2 = 200 # kg/m^3 (density increase per unit concentration)
rho1 = Constant(density1)
rho2 = Constant(density2)
rho = rho1 + rho2*c0
# Gravitational acceleration (m/s^2, pointing in -y), projected into the
# velocity space V.
# NOTE(review): Expression is deprecated/removed in newer Firedrake releases;
# Constant((0.0, -9.81)) is the usual replacement.
g = project(Expression(("0.0","-9.81")),V)

#=============#
#  Viscosity  #
#=============#
# NOTE: despite its name, dynamic_viscosity holds a *kinematic* viscosity
# (m^2/s). The dynamic viscosity used in Darcy's law is mu = nu*rho, so it
# also varies with concentration through rho.
dynamic_viscosity = 1e-6 # m^2/s
nu = Constant(dynamic_viscosity)
mu = nu*rho

#=====================#
#  Volumetric source  #
#=====================#
# Zero source term; used in both the flow and the transport right-hand sides.
f = Constant(0.0)

#===========================#
#  Diffusivity coefficient  #
#===========================#
# Constant scalar diffusivity of the concentration (presumably m^2/s - confirm)
D = Constant(3.565e-6)

#=======================#
#  Boundary conditions  #
#=======================#
# Atmospheric pressure, imposed weakly through the ds(6) term of the flow
# right-hand side.
patm = Constant(101325) # Pascals
# Zero velocity on the four walls; for the RT space this constrains the normal
# flux. Ids 1-4 are presumably Firedrake's default RectangleMesh ids
# (x=0, x=Lx, y=0, y=Ly) - confirm for your version.
# Use a numeric zero: a string literal only works through implicit coercion.
bcflow1 = DirichletBC(W.sub(0), Constant(0.0), (1,2,3,4))
bcflow = [bcflow1]
# Concentration 0 on the bottom wall (id 3) and 1 on the top hole (id 5).
# DG dofs are not attached to facets, so method="geometric" is used to pick
# the nodes that lie on those boundaries.
bctran1 = DirichletBC(Q, Constant(0.0), 3, method="geometric")
bctran2 = DirichletBC(Q, Constant(1.0), 5, method="geometric")
bctran = [bctran1,bctran2]

#=====================#
#  Weak formulations  #
#=====================#
# Outward unit normal on facets
n = FacetNormal(mesh)
# Cell size for the interior-penalty term, taken from the x spacing only.
# NOTE(review): when Ly/yseed differs from Lx/xseed, the penalty scale is set
# by the x spacing alone; a per-cell size (e.g. CellDiameter) would adapt.
h = Constant(Lx/float(xseed))
# Interior-penalty parameter
gamma = Constant(4)
# Time-step size (seconds)
dt = Constant(tstep)

# Flow problem
# Mixed Darcy form corresponding to (mu/k) v + grad p = rho g, with the mass
# equation tested by q. The boundary term -(w.n) patm on ds(6) imposes
# p = patm weakly on faces labelled 6. Because mu and rho depend on c0, the
# flow couples to the previous concentration.
# NOTE(review): as written the mass equation reads div(v) = -f. This is
# harmless while f = 0; check the sign before adding a non-zero source.
a_flow = (dot(w,mu/k*v) - div(w)*p - q*div(v))*dx
L_flow = q*f*dx + dot(w,rho*g)*dx - dot(w,n)*patm*ds(6)

# Transport problem
# vn = max(vel.n, 0): the outflow part of the normal velocity, used to pick
# the upwind value on each facet.
vn = 0.5*(dot(vel,n) + abs(dot(vel,n)))
# Terms of a_tran, in order:
#   1. backward-Euler mass term (paired with d*c0/dt in L_tran)
#   2. diffusion volume term
#   3-5. symmetric interior-penalty terms on interior facets
#        (consistency, symmetry, penalty gamma/h)
#   6. advection volume term after integration by parts
#   7. upwind advective flux across interior facets
#   8. advective outflow through the bottom wall (id 3)
# NOTE(review): the velocity BC sets v.n = 0 on ids 1-4, so term 8 likely
# vanishes. Confirm this is intended.
a_tran = d*c/dt*dx \
  + inner(grad(d), D*grad(c)) * dx \
  - dot(jump(d,n),avg(D*grad(c)))*dS \
  - dot(avg(D*grad(d)),jump(c,n))*dS \
  + gamma/h*dot(jump(d,n),jump(c,n))*dS \
  - dot(grad(d),vel*c)*dx \
  + dot(jump(d),vn('+')*c('+')-vn('-')*c('-'))*dS \
  + dot(d, vn*c)*ds(3)
# Source term plus the previous-step concentration (backward Euler)
L_tran = d * f * dx + d*c0/dt*dx

#=================#
#  Solver params  #
#=================#
# Krylov method and convergence settings common to both solvers
_shared_ksp = {
  'ksp_type': 'gmres',
  'ksp_rtol': 1e-7,
  'ksp_atol': 1e-50,
  'ksp_converged_reason': True,
}

# Darcy flow: full Schur-complement fieldsplit. Velocity block: ILU.
# Schur block (approximated with 'selfp'): BoomerAMG.
# Alternatives tried for the Schur block: boomeramg strong_threshold 0.75
# with agg_nl 2, or 'ml' in place of hypre.
flow_parameters = dict(_shared_ksp)
flow_parameters.update({
  'pc_type': 'fieldsplit',
  'pc_fieldsplit_type': 'schur',
  'pc_fieldsplit_schur_fact_type': 'full',
  'pc_fieldsplit_schur_precondition': 'selfp',
  'fieldsplit_0_ksp_type': 'preonly',
  'fieldsplit_0_pc_type': 'ilu',
  'fieldsplit_1_ksp_type': 'preonly',
  'fieldsplit_1_pc_type': 'hypre',
  'fieldsplit_1_pc_hypre_type': 'boomeramg',
})

# Transport: BoomerAMG on the whole advection-diffusion operator
tran_parameters = dict(_shared_ksp)
tran_parameters.update({
  'pc_type': 'hypre',
  'pc_hypre_type': 'boomeramg',
  'pc_hypre_boomeramg_strong_threshold': 0.75,
  'pc_hypre_boomeramg_agg_nl': 2,
})

#=================#
#  Setup solvers  #
#=================#
# Both operators depend on fields that change between time steps (mu/rho
# via c0 in the flow form, vel in the transport form), so neither is
# declared constant and both are re-assembled on every solve.
problem_flow = LinearVariationalProblem(
  a_flow, L_flow, u, bcs=bcflow, constant_jacobian=False)
problem_tran = LinearVariationalProblem(
  a_tran, L_tran, conc, bcs=bctran, constant_jacobian=False)

# Distinct options prefixes let each solver be tuned separately with
# flow_* / tran_* PETSc options.
solver_flow = LinearVariationalSolver(
  problem_flow, options_prefix="flow_", solver_parameters=flow_parameters)
solver_tran = LinearVariationalSolver(
  problem_tran, options_prefix="tran_", solver_parameters=tran_parameters)

#==========#
#  Output  #
#==========#
import os
# Use only the script's base name, without directories or extension. This way
# running it via a path (e.g. python runs/2D_Darcy_Homo_Iso.py) still writes
# into Figures/ instead of a nested Figures/<path>/ directory.
script_name = os.path.splitext(os.path.basename(__file__))[0]
file_prefix = 'Figures/' + script_name + '_' + str(xseed) + '_x_' + str(yseed)
outfile = File(file_prefix+'.pvd')

#=================#
#  Solve problem  #
#=================#
# Each step: solve the Darcy flow using the density from the previous
# concentration c0, then solve transport with the new velocity, optionally
# write output, and finally copy the new concentration into c0.
printcount = 0
while t < T:
  t += tstep
  if rank == 0:
    # t is in seconds; convert so the printed value matches the "hours" label.
    print('\tTime level: %.3e hours' % (t/3600.0))

  # Flow solver (solution reset before each solve)
  u.assign(0)
  initialTime = time.time()
  solver_flow.solve()

  # Transport solver (solution reset before each solve)
  conc.assign(0)
  solver_tran.solve()

  if rank == 0:
    # Combined wall-clock time of the flow and transport solves
    finalTime = time.time() - initialTime
    print('Wall-clock time: %.3e seconds' % finalTime)

  # Output every `printat` time steps
  printcount += 1
  if printat == printcount:
    printcount = 0
    outfile.write(conc,vel,pres)

  # Update: rho (and hence mu and the flow RHS) is an expression of c0, so
  # this also refreshes the density for the next flow solve.
  c0.assign(conc)

