from firedrake import *
from firedrake.petsc import PETSc
import pypapi
import numpy as np
# Mixed poisson
# Homogeneous dirichlet BCs on all walls
# Forcing
# f = -(\pi^2/2) (4 cos(\pi x) - 5 cos(\pi x/2) + 2) sin(\pi y)
# With exact solution
# sin(\pi x) tan(\pi x/4) sin(\pi y)

# PAPI events recorded around the solve, in this order: load
# instructions, store instructions, floating point operations.
events = np.asarray(
    [
        pypapi.Event.LD_INS,
        pypapi.Event.SR_INS,
        pypapi.Event.FP_OPS,
    ],
    dtype=pypapi.EventType)

# Unit square with 500 subdivisions in each direction
mesh = UnitSquareMesh(500, 500)
n = FacetNormal(mesh)

# Lowest-order Raviart-Thomas space for sigma paired with piecewise
# constants for u
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
W = V*Q

sigma, u = TrialFunctions(W)
tau, v = TestFunctions(W)

# Forcing term, interpolated into the piecewise constant space
f = Function(Q)
f.interpolate(Expression(
    "-0.5*pi*pi*(4*cos(pi*x[0]) - 5*cos(pi*x[0]*0.5) + 2)*sin(pi*x[1])"))

# Exact solution in the same space, used to measure the error after
# the solve
exact = Function(Q)
exact.interpolate(Expression(
    "sin(pi*x[0])*tan(pi*x[0]*0.25)*sin(pi*x[1])"))

# Bilinear form and right-hand side of the mixed problem
a = (dot(sigma, tau) + div(tau)*u + div(sigma)*v)*dx
L = -f*v*dx

# Mixed function that receives the computed (sigma, u) pair
solution = Function(W)

# To see how we're doing compared to an exact inverse
#problem_direct = LinearVariationalProblem(a, L, solution,
                                          # Build monolithic, rather
                                          # than block, matrix for the
                                          # spaces, so that MUMPS can
                                          # see and invert the whole thing
#                                          nest=False)

#direct_parameters = {
#    'ksp_type': 'preonly',
#    'pc_type': 'lu',
#    'pc_factor_mat_solver_package': 'mumps'
#    }

#solver_direct = LinearVariationalSolver(problem_direct,
#                                        options_prefix="direct_",
#                                        solver_parameters=direct_parameters)


# Solver options: GMRES preconditioned by an upper Schur complement
# factorisation, where S = D - C A^{-1} B is preconditioned with
# Sp = D - C diag(A)^{-1} B
selfp_parameters = {
    # Outer Krylov method, reporting the true residual as it iterates
    'ksp_type': 'gmres',
    'ksp_monitor_true_residual': True,
    # Schur complement fieldsplit using the upper factorisation and
    # the "selfp" approximation to S
    'pc_type': 'fieldsplit',
    'pc_fieldsplit_type': 'schur',
    'pc_fieldsplit_schur_fact_type': 'upper',
    'pc_fieldsplit_schur_precondition': 'selfp',
    # A block: one application of process-block ILU(0) stands in for
    # the inverse
    'fieldsplit_0_ksp_type': 'preonly',
    'fieldsplit_0_pc_type': 'bjacobi',
    'fieldsplit_0_sub_pc_type': 'ilu',
    # S block: one V-cycle on Sp stands in for the inverse.  The action
    # of the Schur complement itself is never applied; the outer GMRES
    # iteration is left to fix things up
    'fieldsplit_1_ksp_type': 'preonly',
    'fieldsplit_1_pc_type': 'hypre'
    }

problem_selfp = LinearVariationalProblem(a, L, solution)

solver = LinearVariationalSolver(problem_selfp,
                                 solver_parameters=selfp_parameters,
                                 options_prefix="selfp_")


#alpha = Constant(4.0)
#gamma = Constant(8.0)
#h = CellSize(mesh)
#h_avg = (h('+') + h('-'))/2
#a_dg = dot(sigma, tau)*dx  + \
#       dot(grad(v), grad(u))*dx \
#       - dot(avg(grad(v)), jump(u, n))*dS \
#       - dot(jump(v, n), avg(grad(u)))*dS \
#       + alpha/h_avg * dot(jump(v, n), jump(u, n))*dS \
#       - dot(grad(v), u*n)*ds \
#       - dot(v*n, grad(u))*ds \
#       + (gamma/h)*dot(v, u)*ds

# aP is an optional form that is assembled and then used by PETSc to
# build the preconditioning matrices
#problem_ip = LinearVariationalProblem(a, L, solution,
#                                      aP=a_dg)

#ip_parameters = {
#    'ksp_type': 'gmres',
#    'ksp_monitor_true_residual': True,
    # Upper Schur factorisation
    # Precondition S = D - C A^{-1} B with an interior penalty DG
    # Laplacian (since it's spectrally equivalent)
#    'pc_type': 'fieldsplit',
#    'pc_fieldsplit_type': 'schur',
#    'pc_fieldsplit_schur_fact_type': 'upper',
    # a11 is the IP-DG block from aP
#    'pc_fieldsplit_schur_precondition': 'a11',
    # Approximately invert A with a single application of
    # process-block ILU(0)
#    'fieldsplit_0_ksp_type': 'preonly',
#    'fieldsplit_0_pc_type': 'bjacobi',
#    'fieldsplit_0_sub_pc_type': 'ilu',
    # Approximately invert S with a single V-cycle on Sp
    # This means we never apply the action of the schur complement and
    # let the outer GMRES iteration fix things up
#    'fieldsplit_1_ksp_type': 'preonly',
#    'fieldsplit_1_pc_type': 'hypre'
#    }
    
#solver_ip = LinearVariationalSolver(problem_ip,
#                                    options_prefix="ip_",
#                                    solver_parameters=ip_parameters)

# If the boundary conditions are such that div-div is a norm rather than
# a semi-norm, we don't need the mass term in the H(div) part
#a_riesz = u*v*dx + div(sigma)*div(tau)*dx + dot(sigma, tau)*dx

#problem_riesz = LinearVariationalProblem(a, L, solution,
#                                         aP=a_riesz)

#riesz_parameters = {
#    'ksp_type': 'gmres',
#    'ksp_monitor_true_residual': True,
#    # Additive fieldsplit
#    # Precondition by the H(div)-L2 inner product
#    'pc_type': 'fieldsplit',
#    'pc_fieldsplit_type': 'additive',
#    # Invert H(div) part with direct solver
#    # Could hook in to HYPRE's auxiliary space ADS multigrid (PETSc
#    # support is being developed by Stefano Zampini, but we'd need to
#    # add some hooks)
#    'fieldsplit_0_ksp_type': 'preonly',
#    'fieldsplit_0_pc_type': 'lu',
#    'fieldsplit_0_pc_factor_mat_solver_package': 'mumps',
    # 1-1 block is just a DG mass matrix, so invert exactly with
    # process-block ILU(0)
#    'fieldsplit_1_ksp_type': 'preonly',
#    'fieldsplit_1_pc_type': 'bjacobi',
#    'fieldsplit_1_sub_pc_type': 'ilu',
#    }
#solver_riesz = LinearVariationalSolver(problem_riesz,
#                                       options_prefix="riesz_",
#                                       solver_parameters=riesz_parameters)
# One slot per requested hardware event.  stop_counters() writes the
# totals into this array (presumably in the same order as `events` --
# confirm against the pypapi documentation).
counts = np.zeros(len(events), dtype=pypapi.CountType)

# Only the solve itself is measured: counters are started immediately
# before it and stopped immediately after.  The PETSc log stage makes
# the same region show up separately in PETSc's own profiling output.
pypapi.start_counters(events)
with PETSc.Log.Stage("selfp"):
    solver.solve()
pypapi.stop_counters(counts)

# Name the three totals instead of using magic indices; this matches
# the LD_INS, SR_INS, FP_OPS order in which `events` was built.
loads, stores, flops = counts

# print(...) with a single parenthesised argument produces the same
# output under Python 2 and is also valid Python 3.
print('Total FLOPS: %e' % flops)
# Floating point operations per load/store instruction
print(float(flops)/float(loads + stores))

# Pull the flux and scalar components out of the mixed solution and
# report the norm of the difference from the interpolated exact solution
sigma, u = solution.split()
print('norm = %f' % norm(assemble(u - exact)))

#plot(u,title="Pressure")
#interactive()

#solution.assign(0)
#with PETSc.Log.Stage("dg_ip"):
#    solver_ip.solve()

#sigma, u = solution.split()

#print norm(assemble(u - exact))



#solution.assign(0)
#with PETSc.Log.Stage("riesz"):
#    solver_riesz.solve()

#sigma, u = solution.split()

#print norm(assemble(u - exact))

#solution.assign(0)
#with PETSc.Log.Stage("direct"):
#    solver_direct.solve()

#sigma, u = solution.split()

#print norm(assemble(u - exact))
