# User parameters
import sys
# Mesh resolution, read from the first command-line argument.  The same value
# is later passed for all three directions of the unit-cube mesh.
if len(sys.argv) < 2:
    sys.exit('usage: %s <seed>   (seed: positive integer mesh resolution)'
             % sys.argv[0])
try:
    seed = int(sys.argv[1])
except ValueError:
    sys.exit('seed must be an integer, got %r' % sys.argv[1])
if seed < 1:
    sys.exit('seed must be a positive integer, got %d' % seed)

# Collect hardware counters around the solve and report derived metrics.
useCounters = False
# Compare the computed velocity against the closed-form reference field.
calculateError = True
# Write the pressure and velocity fields to .pvd files.
printSolution = False

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
from mpi4py import MPI
import pypapi
import numpy as np
# World communicator and this process's position within it.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

if useCounters:
    # Hardware events to sample.  The metrics section indexes the resulting
    # counts by position, so the order here matters: slot 0 feeds the L1
    # hit rate, slots 1 and 2 are summed as its denominator, and slot 3
    # feeds the FLOPS figures.
    events = np.asarray(
        [
            pypapi.Event.L1_DCM,
            pypapi.Event.LD_INS,
            pypapi.Event.SR_INS,
            pypapi.Event.FP_OPS,
        ],
        dtype=pypapi.EventType,
    )

# Only the root rank announces the discretization.
if rank == 0:
    print('Discretization: RT0')
#========================
#  Discretization 
#========================
# Unit-cube mesh; `seed` is used for all three directions.
mesh = UnitCubeMesh(seed, seed, seed)

# Mixed space W = V x Q: "RT" of degree 1 for the vector unknown and "DG" of
# degree 0 for the scalar unknown.
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
W = V * Q
v, p = TrialFunctions(W)
w, q = TestFunctions(W)

# Right-hand-side data: a closed-form expression interpolated into Q.
forcing = "12*pi*pi*sin(pi*x[0]*2)*sin(pi*x[1]*2)*sin(2*pi*x[2])"
f = Function(Q)
f.interpolate(Expression(forcing))

# Bilinear form, written as its three contributions under one measure.
mass_term = dot(v, w)
pressure_term = p*div(w)
divergence_term = div(v)*q
a = (mass_term - pressure_term - divergence_term)*dx

# Linear form: the forcing tested against the scalar test function.
L = f*q*dx

#========================
#  Parameters
#========================
# Solver options for the "selfp" configuration, assembled group by group.
selfp_parameters = {}

# Outer Krylov iteration, with true-residual monitoring switched on.
selfp_parameters.update({
    'ksp_type': 'gmres',
    'ksp_monitor_true_residual': True,
})

# Field-split preconditioner using an upper Schur factorisation.
# The Schur complement S = D - C A^{-1} B is preconditioned with
# Sp = D - C diag(A)^{-1} B.
selfp_parameters.update({
    'pc_type': 'fieldsplit',
    'pc_fieldsplit_type': 'schur',
    'pc_fieldsplit_schur_fact_type': 'upper',
    'pc_fieldsplit_schur_precondition': 'selfp',
})

# Block 0: A is inverted approximately by one application of
# process-block ILU(0).
selfp_parameters.update({
    'fieldsplit_0_ksp_type': 'preonly',
    'fieldsplit_0_pc_type': 'bjacobi',
    'fieldsplit_0_sub_pc_type': 'ilu',
})

# Block 1: S is inverted approximately by a single V-cycle on Sp.  The action
# of the Schur complement is therefore never applied; the outer GMRES
# iteration is left to correct for that.
selfp_parameters.update({
    'fieldsplit_1_ksp_type': 'preonly',
    'fieldsplit_1_pc_type': 'hypre',
})

#========================
#  Solve problem
#========================
# Function on the mixed space that receives the solution.
solution = Function(W)

if useCounters:
    # Two four-slot buffers: `counts` receives this rank's counter values,
    # `reduceCounts` receives the values reduced across ranks afterwards.
    counts = np.zeros(4, dtype=pypapi.CountType)
    reduceCounts = np.zeros(counts.shape, dtype=counts.dtype)
    pypapi.start_counters(events)

# Everything between the two timer reads is included in the reported time.
initialTime = pypapi.get_real_usec()
if rank == 0:
    print('MPI processes %d: solving... ' % size)
with PETSc.Log.Stage("selfp"):
    # Assemble operator and right-hand side, then solve inside the log stage.
    A = assemble(a)
    b = assemble(L)
    solver = LinearSolver(
        A,
        solver_parameters=selfp_parameters,
        options_prefix="selfp_",
    )
    solver.solve(solution, b)
endTime = pypapi.get_real_usec()

if useCounters:
    pypapi.stop_counters(counts)

#=======================
#  Performance metrics
#=======================
# Sum every hardware counter over all ranks.  All ranks take part in the
# reduction even though only rank 0 reports the results.
if useCounters:
    comm.Allreduce(counts, reduceCounts, op=MPI.SUM)

if rank == 0:
    # The timer values come from get_real_usec(); scale by 1e-6 for seconds.
    elapsedTime = float(endTime - initialTime) * 1e-6
    print(A.M.handle.getSizes())
    print('\tSolver time: %e' % elapsedTime)
    print('\tSolver iterations: %d' % solver.ksp.getIterationNumber())
    if useCounters:
        # Slots follow the order of `events`:
        # 0 = L1_DCM, 1 = LD_INS, 2 = SR_INS, 3 = FP_OPS.
        memoryOps = float(reduceCounts[2] + reduceCounts[1])
        FLOPSS = reduceCounts[3] / elapsedTime
        AI = float(reduceCounts[3]) / memoryOps
        L1hit = 1 - float(reduceCounts[0]) / memoryOps
        print('\tFLOPS/s: %e' % FLOPSS)
        print('\tArithmetic Intensity: %.3f' % AI)
        # '%.3f' (three decimals), matching the line above; the previous
        # '%3f' only set a minimum field width and had no visible effect.
        print('\tL1 hit rate: %.3f' % L1hit)

#=======================
#  Compute L2-norm
#=======================
# Synchronise all ranks before post-processing the solution.
comm.Barrier()

# Split the mixed solution into its two components.
v, p = solution.split()

if calculateError:
    # Closed-form reference velocity, one component per line.
    velocity = Expression((
        "2*pi*cos(2*pi*x[0])*sin(2*pi*x[1])*sin(2*pi*x[2])",
        "2*pi*sin(2*pi*x[0])*cos(2*pi*x[1])*sin(2*pi*x[2])",
        "2*pi*sin(2*pi*x[0])*sin(2*pi*x[1])*cos(2*pi*x[2])",
    ))
    # Project the reference field into the velocity space V.
    exact = Function(V)
    exact.project(velocity)
    # Square root of the integrated squared difference between the computed
    # and reference velocities.
    L2_square = dot(v - exact, v - exact) * dx
    L2_error = sqrt(assemble(L2_square))
    if rank == 0:
        # Label, one separating space, then the value.
        print('\tL2 error norm = ' + ' ' + str(L2_error))

#========================
#  Output solutions
#========================
# Optionally write both solution components to .pvd files under Figures/,
# pressure first and then velocity.
if printSolution:
    outputs = (
        ("Figures/RT0_pressure.pvd", p),
        ("Figures/RT0_velocity.pvd", v),
    )
    for path, field in outputs:
        File(path) << field
