#=====================================================================
#  Darcy - LS formulation
#  By Justin Chang:
#
#  Usage: python Darcy_FE.py <seed> <nest>
#
#    <seed> = Number of cells in all three spatial directions
#    <nest> = Use monolithic (0) or nested system (1)
#=====================================================================

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import time
import sys
# Parse the two required positional arguments (see Usage in the header).
# Extra trailing arguments are ignored here rather than rejected.
_usage = "Usage: python %s <seed> <nest>" % sys.argv[0]
if len(sys.argv) < 3:
    sys.exit(_usage)
try:
    seed = int(sys.argv[1])   # cells per spatial direction
    nest = int(sys.argv[2])   # 0 = monolithic matrix, nonzero = nested
except ValueError:
    sys.exit(_usage)
if seed < 1:
    # A mesh needs at least one cell per direction.
    sys.exit("<seed> must be a positive integer\n" + _usage)

# MPI information taken from PETSc's world communicator.
comm = PETSc.COMM_WORLD
rank = comm.getRank()
size = comm.getSize()

#========================
#  Discretization 
#========================
# Unit cube mesh with `seed` subdivisions along each of x, y and z.
mesh = UnitCubeMesh(seed, seed, seed)

# Mixed space W = V x Q: continuous piecewise-quadratic vector velocity
# paired with continuous piecewise-linear scalar pressure.
V = VectorFunctionSpace(mesh, "CG", 2)
Q = FunctionSpace(mesh, "CG", 1)
W = V * Q
(v, p), (w, q) = TrialFunctions(W), TestFunctions(W)

#========================
#  Forcing function
#========================
# Source term, interpolated into the pressure space Q.
f = Function(Q)
f.interpolate(Expression(
    "12*pi*pi*sin(pi*x[0]*2)*sin(pi*x[1]*2)*sin(2*pi*x[2])"))

#========================
#  Weak form
#========================
# Normal equations of the least-squares functional
#   ||v + grad(p)||^2 + ||div(v) - f||^2
# over the mixed unknown (v, p).
a = dot(v + grad(p), w + grad(q)) * dx + div(v) * div(w) * dx
L = f * div(w) * dx
 
#========================
#  solver parameters
#========================
# Operator storage: nested MatNest blocks when <nest> is nonzero, otherwise
# a single monolithic matrix.
parameters["matnest"] = bool(nest)

if nest:
    # Field-split preconditioner with one hypre (BoomerAMG by default) solve
    # per field.
    # CG requires a symmetric positive definite preconditioner. Plain
    # 'multiplicative' is a one-sided block Gauss-Seidel sweep and is NOT
    # symmetric, so the forward+backward 'symmetric_multiplicative' variant
    # is used instead.
    # Previously explored alternatives: a Schur-complement split
    # (fact_type 'full', precondition 'selfp') and inner CG solves per field.
    solver_parameters = {
        'ksp_type': 'cg',
        'pc_type': 'fieldsplit',
        'pc_fieldsplit_type': 'symmetric_multiplicative',
        'fieldsplit_0_ksp_type': 'preonly',
        'fieldsplit_0_pc_type': 'hypre',
        'fieldsplit_1_ksp_type': 'preonly',
        'fieldsplit_1_pc_type': 'hypre',
        'ksp_monitor_true_residual': True,
        'ksp_converged_reason': True,
    }
else:
    # Monolithic system: CG preconditioned by BoomerAMG on the whole matrix
    # ('pc_type': 'gamg' is an alternative that was tried).
    solver_parameters = {
        'ksp_type': 'cg',
        'pc_type': 'hypre',
        'pc_hypre_type': 'boomeramg',
        'ksp_monitor_true_residual': True,
        'ksp_converged_reason': True,
    }

#========================
#  Solve problem
#========================
solution = Function(W)
# NOTE: the timer starts before assembly, so the reported "Solver time"
# covers assembling A and b as well as the Krylov solve.
initialTime = time.time()
if rank == 0:
    print('MPI processes %d: solving... ' % size)

# Homogeneous Dirichlet condition p = 0 on boundary markers 1-6 (all faces
# of the unit cube mesh). The sin-product forcing also vanishes there.
bcs = DirichletBC(W.sub(1), 0, (1, 2, 3, 4, 5, 6))

# Alternative (formerly commented out): prescribe velocity components on the
# faces instead of the pressure. Then p is determined only up to a constant,
# and the solver must be given the constant-pressure null space:
#   nullspace = MixedVectorSpaceBasis(
#       W, [W.sub(0), VectorSpaceBasis(constant=True)])
A = assemble(a, bcs=bcs)
b = assemble(L, bcs=bcs)

# No null space is attached. The pressure Dirichlet condition already
# removes the constant mode, so the constant-pressure vector is not in the
# kernel of A. Projecting it out of the Krylov iterates would corrupt the
# solution.
solver = LinearSolver(A, solver_parameters=solver_parameters,
                      options_prefix="solver_")
solver.solve(solution, b)
endTime = time.time()
endTime = time.time()

#=======================
#  Performance metrics
#=======================
# Report on rank 0 only. Single-argument print(...) produces the same output
# on Python 2 and is also valid Python 3 syntax.
if rank == 0:
    elapsedTime = endTime - initialTime
    # Size information of the assembled PETSc matrix (petsc4py Mat.getSizes).
    print(A.M.handle.getSizes())
    print('\tSolver time: %e' % elapsedTime)
    print('\tSolver iterations: %d' % solver.ksp.getIterationNumber())

#=======================
#  Compute L2-norm (TODO: not implemented yet; this section only splits the mixed solution)
#=======================
# Synchronize all ranks before post-processing.
comm.Barrier()
# Velocity and pressure components of the mixed solution.
v, p = solution.split()

#========================
#  Output solutions
#========================
# One ParaView .pvd file per field, pressure first, then velocity.
for _path, _field in (("Figures/LS_pressure.pvd", p),
                      ("Figures/LS_velocity.pvd", v)):
    File(_path) << _field
