#================================================
#
#  Darcy's equation - RT0 formulation
#  Solves v + grad(p) = g, div(v) = 0 on the unit square
#  By: Justin Chang
#
#  v_x = cos(pi*x)*sin(pi*y)
#  v_y = -sin(pi*x)*cos(pi*y)
#  p   = sin(2*pi*x)*sin(2*pi*y)
#
#================================================
import sys
# Mesh resolution: number of cells along each side of the unit square,
# read from the first command-line argument.  Fail with a clear message
# instead of an IndexError traceback when it is missing, and reject
# non-positive values before they reach mesh construction.
if len(sys.argv) < 2:
    sys.exit("usage: %s <cells-per-side>" % sys.argv[0])
seed = int(sys.argv[1])
if seed < 1:
    sys.exit("error: cells-per-side must be a positive integer, got %d" % seed)

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
from mpi4py import MPI
import numpy as np
import time
# Global MPI communicator and this process's rank within it; the rank is
# used below to restrict console output to a single process.
comm = MPI.COMM_WORLD
rank = comm.rank

#========================
#  Create mesh 
#========================
# seed x seed cell mesh of the unit square, with lowest-order
# Raviart-Thomas velocities (V) paired with piecewise-constant
# pressures (Q) in the mixed space W = V x Q.
mesh = UnitSquareMesh(seed, seed)
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
W = MixedFunctionSpace((V, Q))
(v, p) = TrialFunctions(W)
(w, q) = TestFunctions(W)

#========================
#  Weak form
#========================
# Manufactured forcing g = v_exact + grad(p_exact), i.e. the velocity from
# the header plus the gradient of p = sin(2*pi*x)*sin(2*pi*y); it is
# represented in the velocity space V by projection.
g = Function(V)
g.project(Expression((
    "cos(pi*x[0])*sin(pi*x[1])+2*pi*cos(2*pi*x[0])*sin(2*pi*x[1])",
    "-sin(pi*x[0])*cos(pi*x[1])+2*pi*sin(2*pi*x[0])*cos(2*pi*x[1])",
)))
# Mixed weak form (w, v) - (div w, p) - (q, div v) = (w, g).  No boundary
# integral appears, which corresponds to p = 0 on the boundary (satisfied
# by the exact pressure).
a = (dot(w, v)*dx
     - div(w)*p*dx
     - q*div(v)*dx)
L = dot(w, g)*dx

#========================
#  Solver parameters
#========================
# PETSc options for the monolithic saddle-point solve: GMRES preconditioned
# by an upper-triangular Schur-complement fieldsplit.
#   * Velocity block A: one application of process-block ILU(0)
#     (block Jacobi with ILU sub-solves).
#   * Schur complement S = D - C A^{-1} B: preconditioned through
#     Sp = D - C diag(A)^{-1} B ('selfp') with a single hypre application
#     (BoomerAMG by default; PETSc must be built with hypre).  The action
#     of S itself is never applied; the outer GMRES iteration corrects for
#     the approximation.
selfp_parameters = dict(
    # ksp_monitor_true_residual=True,   # uncomment to monitor convergence
    ksp_type='gmres',
    pc_type='fieldsplit',
    pc_fieldsplit_type='schur',
    pc_fieldsplit_schur_fact_type='upper',
    pc_fieldsplit_schur_precondition='selfp',
    fieldsplit_0_ksp_type='preonly',
    fieldsplit_0_pc_type='bjacobi',
    fieldsplit_0_sub_pc_type='ilu',
    fieldsplit_1_ksp_type='preonly',
    fieldsplit_1_pc_type='hypre',
)

#========================
#  Solve problem
#========================
# Assemble the full mixed system and solve it with the fieldsplit options
# above; the measured wall-clock time covers assembly and solve together.
solution = Function(W)
initialTime = time.time()
A, b = assemble(a), assemble(L)
solver = LinearSolver(A, solver_parameters=selfp_parameters,
                      options_prefix="selfp_")
solver.solve(solution, b)
getTime = time.time() - initialTime

#=======================
#  Compute L2 error
#=======================
# Split the mixed solution into velocity and pressure components (this
# rebinds v and p; the trial functions are not needed any more).
v, p = solution.split()
velocity = Expression(("cos(pi*x[0])*sin(pi*x[1])", "-sin(pi*x[0])*cos(pi*x[1])"))
pressure = Expression("sin(2*pi*x[0])*sin(2*pi*x[1])")
# Represent the exact solution in higher-degree continuous spaces, NOT in
# V and Q themselves.  Projecting onto the discrete spaces (e.g. DG0 for
# the pressure) would measure only the distance between the discrete
# solution and the projection of the exact one.  That hides the
# discretization error and would report misleading convergence rates.
exact_degree = 3
V_exact = VectorFunctionSpace(mesh, "CG", exact_degree)
Q_exact = FunctionSpace(mesh, "CG", exact_degree)
v_exact = Function(V_exact)
p_exact = Function(Q_exact)
v_exact.project(velocity)
p_exact.project(pressure)
L2_v_error = sqrt(assemble(dot(v - v_exact, v - v_exact)*dx))
L2_p_error = sqrt(assemble((p - p_exact)*(p - p_exact)*dx))

#========================
#  Output solutions
#========================
# Only rank 0 reports the error norms and timing, so output is not
# duplicated under MPI.  Single-argument print(...) behaves the same under
# Python 2 and Python 3 (the former print statements were a SyntaxError
# on Python 3).
if rank == 0:
    print('\tL2 error velocity: %.3e' % L2_v_error)
    print('\tL2 error pressure: %.3e' % L2_p_error)
    print('\tWall-clock time: %.3e seconds' % getTime)
# NOTE(review): file output is deliberately left outside the rank check.
# In parallel, every rank presumably has to take part in writing the
# distributed fields, so confirm before moving this.
File("Figures/RT0_pressure.pvd") << p
File("Figures/RT0_velocity.pvd") << v
