# User parameters
import sys
# Command line: <FEType> <seed>. FEType selects the discretization and seed is
# the number of cells per edge of the unit-cube mesh. Bad or missing arguments
# stop the run with a usage line instead of an IndexError/ValueError traceback.
_USAGE = 'usage: %s <FEType: RT|BDM|TH|VMS|LS> <cells per edge>' % sys.argv[0]
if len(sys.argv) < 3:
    sys.exit(_USAGE)
FEType = sys.argv[1]
try:
    seed = int(sys.argv[2])
except ValueError:
    seed = 0
if seed < 1:
    # A mesh needs at least one cell per edge.
    sys.exit(_USAGE)
# Feature toggles.
useCounters = False     # sample PAPI hardware counters around the solve
calculateError = False  # report the L2 velocity error against the exact solution
printSolution = False   # write pressure/velocity fields to .pvd files

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
from mpi4py import MPI
import pypapi
import numpy as np
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
if useCounters:
    # PAPI events sampled around the solve. The list order defines the slot
    # each count lands in, so it must agree with the indx* constants below and
    # with the length-4 count arrays allocated before the solve.
    events = np.asarray([
        pypapi.Event.L1_TCM,
        pypapi.Event.LD_INS,
        pypapi.Event.SR_INS,
        pypapi.Event.FP_OPS,
    ], dtype=pypapi.EventType)
    indxL1 = 0
    indxLD = 1
    indxSR = 2
    indxFPOPS = 3
    # L3_TCM and TOT_CYC are not sampled: these names alias the L1 and SR
    # slots and are kept only so existing references still resolve.
    indxL3 = 0
    indxTOTCYC = 2
if rank == 0:
    print('Discretization: %s' % FEType)
#========================
#  Discretization
#========================
# Formulations that have a bilinear form below. Checked before any Firedrake
# work so a typo fails immediately instead of as a NameError on `a` later.
_supportedFETypes = ('RT', 'BDM', 'TH', 'VMS', 'LS')
if FEType not in _supportedFETypes:
    raise ValueError('Unknown FEType %r; expected one of %s'
                     % (FEType, ', '.join(_supportedFETypes)))
mesh = UnitCubeMesh(seed, seed, seed)
if FEType in ('RT', 'BDM'):
    # Degree-1 RT/BDM velocity with piecewise-constant (DG0) pressure.
    V = FunctionSpace(mesh, FEType, 1)
    Q = FunctionSpace(mesh, "DG", 0)
elif FEType == 'TH':
    # Taylor-Hood pair: continuous P2 velocity, continuous P1 pressure.
    V = VectorFunctionSpace(mesh, "CG", 2)
    Q = FunctionSpace(mesh, "CG", 1)
else:
    # VMS and LS: equal-order continuous P1 velocity and pressure.
    V = VectorFunctionSpace(mesh, "CG", 1)
    Q = FunctionSpace(mesh, "CG", 1)
W = V * Q
v, p = TrialFunctions(W)
w, q = TestFunctions(W)
f = Function(Q)
f.interpolate(Expression("12*pi*pi*sin(pi*x[0]*2)*sin(pi*x[1]*2)*sin(2*pi*x[2])"))
if FEType in ('BDM', 'RT', 'TH'):
    # Mixed form: (v, w) - (p, div w) - (div v, q) = (f, q).
    a = (dot(v, w) - p*div(w) - div(v)*q)*dx
    L = f*q*dx
elif FEType == 'VMS':
    # Mixed form plus a -1/2-weighted (v + grad p, w + grad q) stabilisation.
    a = (dot(v, w) - p*div(w) - div(v)*q - 0.5*dot(v+grad(p), w+grad(q)))*dx
    L = f*q*dx
else:
    # 'LS': least-squares form for v + grad p = 0, div v = f.
    a = dot(v+grad(p), w+grad(q))*dx + div(v)*div(w)*dx
    L = f*div(w)*dx
 
#========================
#  Parameters
#========================
# Options for the saddle-point solves: an upper Schur-complement fieldsplit
# preconditioner inside an outer CG iteration.
selfp_parameters = dict(
    ksp_type='cg',
    # Upper Schur factorisation. S = D - C A^{-1} B is preconditioned with
    # Sp = D - C diag(A)^{-1} B ("selfp").
    pc_type='fieldsplit',
    pc_fieldsplit_type='schur',
    pc_fieldsplit_schur_fact_type='upper',
    pc_fieldsplit_schur_precondition='selfp',
    # A-block: an inner CG solve preconditioned with Jacobi.
    fieldsplit_0_ksp_type='cg',
    fieldsplit_0_pc_type='jacobi',
    # S-block: a single GAMG application on Sp with no inner Krylov
    # iterations; the outer iteration corrects for the approximation.
    fieldsplit_1_ksp_type='preonly',
    fieldsplit_1_pc_type='gamg',
)
# NOTE(review): CG assumes a symmetric positive-definite operator and
# preconditioner. The RT/BDM/TH systems are saddle-point (indefinite), the
# upper factorisation is not symmetric, and the inner CG solve makes the
# preconditioner change between iterations. GMRES/FGMRES is the usual outer
# choice here -- confirm CG is intended.

# Plain CG with block-Jacobi preconditioning and a true-residual monitor.
cg_parameters = dict(
    ksp_type='cg',
    ksp_monitor_true_residual=True,
    pc_type='bjacobi',
)

#========================
#  Solve problem
#========================
solution = Function(W)
if useCounters:
    reduceCounts = np.zeros(4, dtype=pypapi.CountType)
    counts = np.zeros(4, dtype=pypapi.CountType)
    pypapi.start_counters(events)
# Wall-clock window later reported as "Solver time"; it brackets assembly as
# well as the linear solve.
initialTime = pypapi.get_real_usec()
if rank == 0:
    print('MPI processes %d: solving... ' % size)
with PETSc.Log.Stage("FEM"):
    # Every formulation is assembled and solved the same way; only the PETSc
    # options prefix differs.
    # NOTE(review): LS is solved without boundary conditions or a nullspace.
    # Candidate velocity Dirichlet data (disabled) was
    #   W.sub(0).sub(0): -2*pi*cos(2*pi*x[0])*sin(2*pi*x[1])*sin(2*pi*x[2]) on (1,2)
    #   W.sub(0).sub(1): -2*pi*sin(2*pi*x[0])*cos(2*pi*x[1])*sin(2*pi*x[2]) on (3,4)
    #   W.sub(0).sub(2): -2*pi*sin(2*pi*x[0])*sin(2*pi*x[1])*cos(2*pi*x[2]) on (5,6)
    # NOTE(review): the "cg_" prefix suggests cg_parameters was meant for LS,
    # but selfp_parameters is used for every formulation -- confirm.
    optionsPrefix = "cg_" if FEType == 'LS' else "selfp_"
    A = assemble(a)
    if rank == 0:
        print(A.M.handle.getSizes())
    b = assemble(L)
    solver = LinearSolver(A, solver_parameters=selfp_parameters,
                          options_prefix=optionsPrefix)
    solver.solve(solution, b)
endTime = pypapi.get_real_usec()
if useCounters:
    pypapi.stop_counters(counts)

#=======================
#  Performance metrics
#=======================
if useCounters:
    # Sum the per-rank hardware counts across all MPI ranks.
    comm.Allreduce(counts, reduceCounts, op=MPI.SUM)
if rank == 0:
    # get_real_usec() values are microseconds; convert the window to seconds.
    elapsedTime = float(endTime-initialTime)*1e-6
    print('\tSolver time: %e' % elapsedTime)
    print('\tSolver iterations: %d' % solver.ksp.getIterationNumber())
    if useCounters:
        # Only metrics backed by distinct counter slots are reported:
        # indxL3 and indxTOTCYC alias the L1 and SR slots, so L3-miss and
        # cycle-based bandwidth figures would be meaningless.
        loads = float(reduceCounts[indxLD])
        stores = float(reduceCounts[indxSR])
        l1Misses = float(reduceCounts[indxL1])
        flops = float(reduceCounts[indxFPOPS])
        memIns = loads + stores
        print('\tPAPI_LD_INS: %e' % loads)
        print('\tPAPI_SR_INS: %e' % stores)
        print('\tPAPI_L1_TCM: %e' % l1Misses)
        print('\tFLOPS/s: %e' % (flops/elapsedTime))
        # Assumes 8 bytes per load/store instruction (double precision) --
        # confirm for the target build.
        print('\tArithmetic Intensity (INS): %.4f' % (flops/(8*memIns)))
        print('\tL1 hit rate: %.4f' % (1 - l1Misses/memIns))

#=======================
#  Compute L2-norm
#=======================
comm.Barrier()
v, p = solution.split()
if calculateError:
    # Exact velocity for f = 12*pi^2*S with
    # S = sin(2*pi*x)*sin(2*pi*y)*sin(2*pi*z).
    # The mixed and VMS forms (right-hand side f*q) give div(v) = -f,
    # i.e. v = +grad(S). The LS form (right-hand side f*div(w)) gives
    # div(v) = +f, i.e. v = -grad(S), so the sign flips for LS.
    # NOTE(review): LS is assembled without boundary conditions, so its
    # discrete solution may not be unique -- interpret its error with care.
    sign = "-" if FEType == 'LS' else ""
    velocity = Expression((
        sign + "2*pi*cos(2*pi*x[0])*sin(2*pi*x[1])*sin(2*pi*x[2])",
        sign + "2*pi*sin(2*pi*x[0])*cos(2*pi*x[1])*sin(2*pi*x[2])",
        sign + "2*pi*sin(2*pi*x[0])*sin(2*pi*x[1])*cos(2*pi*x[2])"))
    exact = Function(V)
    exact.project(velocity)
    L2_square = dot(v - exact, v - exact) * dx
    L2_error = sqrt(assemble(L2_square))
    if rank == 0:
        print '\tL2 error norm = ', L2_error

#========================
#  Output solutions
#========================
if printSolution:
    # Name the files after the discretization actually used. Otherwise every
    # formulation is labelled "RT0" and each run overwrites the previous one.
    File("Figures/%s_pressure.pvd" % FEType) << p
    File("Figures/%s_velocity.pvd" % FEType) << v
