from firedrake import *
from firedrake.petsc import PETSc
from mpi4py import MPI
import pypapi
import numpy as np

# The global MPI communicator, plus this process's rank and the number of
# ranks taking part in the run.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
# PAPI hardware-counter events (currently disabled; see the commented-out
# counter code around the solve):
#events = np.asarray([pypapi.Event.L1_DCM, pypapi.Event.LD_INS, pypapi.Event.SR_INS, pypapi.Event.FP_OPS],
#                    dtype=pypapi.EventType)

# Discretisation: a 20x20x20 unit-cube mesh with a mixed space made of a
# continuous quadratic vector velocity space (V) and a piecewise-constant
# pressure space (Q).
mesh = UnitCubeMesh(20, 20, 20)
V = VectorFunctionSpace(mesh, "CG", 2)
Q = FunctionSpace(mesh, "DG", 0)
W = V * Q

v, p = TrialFunctions(W)
w, q = TestFunctions(W)

# Source term f = 12*pi^2 * sin(2 pi x) sin(2 pi y) sin(2 pi z), interpolated
# into the piecewise-constant space Q.
f = Function(Q)
f.interpolate(Expression("12*pi*pi*sin(2*pi*x[0])*sin(2*pi*x[1])*sin(2*pi*x[2])"))

# Mixed bilinear form with an added -0.5*(v + grad p).(w + grad q) term.
# NOTE(review): p and q live in DG0, so grad(p) and grad(q) are zero inside
# every cell and this extra term reduces to -0.5*dot(v, w); confirm that is
# the intended stabilisation for this pressure space.
a = (dot(v, w) - p*div(w) - div(v)*q)*dx - 0.5*dot(v + grad(p), w + grad(q))*dx
L = f*q*dx

solution = Function(W)

# Homogeneous Dirichlet condition on the pressure on boundary markers 1-6
# (presumably the six faces of the unit cube). The problem is posed on the
# mixed space W, so the conditions are built on its pressure subspace
# W.sub(1); a condition built on the standalone space Q does not identify
# the pressure block of W.
# NOTE(review): in this mixed form p = 0 on the boundary is also the natural
# condition (the boundary term produced by -p*div(w) vanishes); confirm
# whether these strong pressure conditions are wanted at all.
bcs = [DirichletBC(W.sub(1), 0, marker, method="geometric")
       for marker in range(1, 7)]

problem_selfp = LinearVariationalProblem(a, L, solution, bcs=bcs)

selfp_parameters = {
    'ksp_type': 'gmres',
    #'ksp_monitor_true_residual': True,
    # Upper Schur factorisation
    # Precondition S = D - C A^{-1} B with Sp = D - C diag(A)^{-1} B
    'pc_type': 'fieldsplit',
    'pc_fieldsplit_type': 'schur',
    'pc_fieldsplit_schur_fact_type': 'upper',
    'pc_fieldsplit_schur_precondition': 'selfp',
    # Approximately invert A with a single application of
    # process-block ILU(0)
    'fieldsplit_0_ksp_type': 'preonly',
    'fieldsplit_0_pc_type': 'bjacobi',
    'fieldsplit_0_sub_pc_type': 'ilu',
    # Approximately invert S with a single application of hypre on Sp.
    # This means we never apply the action of the Schur complement and
    # let the outer GMRES iteration fix things up.
    'fieldsplit_1_ksp_type': 'preonly',
    'fieldsplit_1_pc_type': 'hypre',
}
solver = LinearVariationalSolver(problem_selfp, options_prefix="selfp_",
                                 solver_parameters=selfp_parameters)

#========================
#  Solve
#========================
#reduceCounts = np.zeros(4, dtype=pypapi.CountType)
#counts = np.zeros(4, dtype=pypapi.CountType)
#pypapi.start_counters(events)
# Only rank 0 reports progress, matching the rank-0-only reporting below;
# otherwise every process prints its own copy of each message.
if rank == 0:
    print('solving... ')
# Synchronise all ranks before starting the clock so that the interval
# measures the solve itself rather than imbalance left over from setup.
comm.Barrier()
initialTime = pypapi.get_real_usec()
with PETSc.Log.Stage("selfp"):
    solver.solve()
endTime = pypapi.get_real_usec()
#pypapi.stop_counters(counts)
if rank == 0:
    print('done')

#=======================
#  Performance metrics
#=======================
#comm.Allreduce(counts,reduceCounts,op=MPI.SUM)
if rank == 0:
    # Rank 0's wall-clock solve time; get_real_usec reports microseconds.
    elapsedTime = (endTime - initialTime) * 1e-6
    # Counter-based metrics (disabled together with the PAPI counters):
    #  reduceCounts = reduceCounts#/size
    #  FLOPSS = reduceCounts[3]/elapsedTime
    # NOTE(review): the AI formula below divides FP ops by load+store
    # instruction counts (flops per memory instruction), not flops per byte.
    #  AI = float(reduceCounts[3])/float(reduceCounts[2]+reduceCounts[1])
    #  L1hit = 1 - float(reduceCounts[0])/float(reduceCounts[2]+reduceCounts[1])
    print('Solver time: %e' % elapsedTime)
    #  print 'FLOPS/s: %e' % FLOPSS
    #  print 'Arithmetic Intensity: %.2f' % AI
    #  print 'L1 hit rate: %2f' % L1hit

comm.Barrier()
v, p = solution.split()

# Exact velocity: the gradient of sin(2 pi x) sin(2 pi y) sin(2 pi z),
# projected into the velocity space V for comparison.
velocity = Expression(("2*pi*cos(2*pi*x[0])*sin(2*pi*x[1])*sin(2*pi*x[2])",
                       "2*pi*sin(2*pi*x[0])*cos(2*pi*x[1])*sin(2*pi*x[2])",
                       "2*pi*sin(2*pi*x[0])*sin(2*pi*x[1])*cos(2*pi*x[2])"))
exact = Function(V)
exact.project(velocity)

# L2 norm of the velocity error.
L2_square = dot(v - exact, v - exact) * dx
L2_error = sqrt(assemble(L2_square))

# Global mass-balance residual. In the form, the q-equation pairs
# -div(v)*q with f*q, i.e. -div(v) = f, so the residual is div(v) + f
# (div(v) - f has the wrong sign). This is a signed integral over the whole
# domain, so local errors of opposite sign can cancel.
LMB_error = assemble((div(v) + f) * dx)
if rank == 0:
    print('L2 error norm = %s' % L2_error)
    print('mass balance error = %s' % LMB_error)

#========================
#  Output solutions
#========================
# Write the pressure and velocity components of the mixed solution to
# ParaView .pvd files under Figures/.
# NOTE(review): the "RT0" prefix does not match this script's discretisation
# (continuous quadratic vector velocity, piecewise-constant pressure);
# consider renaming the outputs.
File("Figures/RT0_pressure.pvd") << p
File("Figures/RT0_velocity.pvd") << v
