from firedrake import *
from math import log

# Per-refinement L2 norms, one entry appended per mesh by the refinement loop:
# velocity error, pressure error, and divergence of the computed velocity.
err_u, err_p, err_div = [], [], []

# Choose the H(div)-conforming velocity family ("RT" or "BDM"); the polynomial
# order passed to FunctionSpace depends on the family chosen.
hdiv_space = "BDM"

if hdiv_space == "RT":
    order = 3
    if COMM_WORLD.rank == 0:
        # Single-argument print() behaves the same under Python 2 and 3.
        print("Using RT space, order = %d" % order)
elif hdiv_space == "BDM":
    order = 2
    if COMM_WORLD.rank == 0:
        print("Using BDM space, order = %d" % order)
else:
    # hdiv_space is also handed to FunctionSpace below, so an unknown family
    # must not silently fall through to the BDM order.
    raise ValueError("Unknown H(div) space: %r" % hdiv_space)

# Mesh resolutions (cells per side) for the convergence study; each entry
# doubles the previous one.
N = [8, 16, 32, 64]

for nn in N:

    mesh = UnitSquareMesh(nn, nn)

    # Velocity in the chosen H(div) family; discontinuous pressure of
    # degree order-1.
    HDIV = FunctionSpace(mesh, hdiv_space, order)
    L2 = FunctionSpace(mesh, "DG", order-1)
    MIXED_FS = HDIV * L2

    # Manufactured solution: velocity (ux, uy) is divergence-free and the
    # source equals -laplace(u) + grad(p) for the pressure pp.
    x, y = SpatialCoordinate(mesh)
    uxexpr = sin(pi*x)*sin(pi*y)
    uyexpr = cos(pi*x)*cos(pi*y)
    ppexpr = sin(pi*x)*cos(pi*y)
    srcxexpr = 2.0*pi*pi*sin(pi*x)*sin(pi*y) + pi*cos(pi*x)*cos(pi*y)
    srcyexpr = 2.0*pi*pi*cos(pi*x)*cos(pi*y) - pi*sin(pi*x)*sin(pi*y)

    # Exact fields are projected into the discrete spaces; the errors below
    # are measured against these projections.
    u_exact = Function(HDIV).project(as_vector([uxexpr, uyexpr]))
    p_exact = Function(L2).project(ppexpr)
    source  = Function(HDIV).project(as_vector([srcxexpr, srcyexpr]))

    (u, p) = TrialFunctions(MIXED_FS)
    (v, q) = TestFunctions(MIXED_FS)

    n = FacetNormal(mesh)
    # NOTE(review): newer UFL replaced CellSize with CellDiameter; confirm
    # against the installed Firedrake version.
    h = CellSize(mesh)
    sigma = 10.0  # interior-penalty parameter

    # Vector Laplacian plus the pressure/divergence coupling.
    a = inner(grad(u), grad(v))*dx - p*div(v)*dx - q*div(u)*dx

    # Symmetric interior penalty terms on interior facets: H(div) velocities
    # are only normally continuous, so tangential jumps are penalised weakly.
    a += (- inner(avg(grad(u)), outer(jump(v), n("+")))*dS
          - inner(avg(grad(v)), outer(jump(u), n("+")))*dS
          + (sigma/avg(h))*inner(jump(u), jump(v))*dS )

    # Nitsche-type boundary terms; the matching u_exact terms are in L.
    a += (- inner(grad(u), outer(v, n))*ds
          - inner(grad(v), outer(u, n))*ds
          + (sigma/h)*inner(u, v)*ds )

    L = ( inner(source, v)*dx
          + (sigma/h)*inner(u_exact, v)*ds
          - inner(grad(v), outer(u_exact, n))*ds)

    # Boundary markers: left = 1, right = 2, bottom = 3, top = 4.
    # On the H(div) space these BCs presumably fix the normal-component dofs,
    # so left/right use the x-component and bottom/top the y-component.
    # Each projection is identical for the paired sides, so compute it once.
    bcfunc_x = Function(HDIV).project(as_vector([u_exact[0], 0]))
    bcfunc_y = Function(HDIV).project(as_vector([0, u_exact[1]]))
    bcs = [DirichletBC(MIXED_FS.sub(0), bcfunc_x, 1),
           DirichletBC(MIXED_FS.sub(0), bcfunc_x, 2),
           DirichletBC(MIXED_FS.sub(0), bcfunc_y, 3),
           DirichletBC(MIXED_FS.sub(0), bcfunc_y, 4)]

    UP = Function(MIXED_FS)

    # Pressure is determined only up to a constant: declare that nullspace.
    nullspace = MixedVectorSpaceBasis(
        MIXED_FS, [MIXED_FS.sub(0), VectorSpaceBasis(constant=True)])

    # Named solver_parameters so it does not shadow firedrake's global
    # `parameters` object brought in by `from firedrake import *`.
    solver_parameters = {
        "mat_type": "matfree",

        "ksp_type": "fgmres",
        "ksp_max_it": "500",
        "ksp_atol": "1.e-16",
        "ksp_rtol": "1.e-11",
        "ksp_monitor_true_residual": True,
        "ksp_view": False,
        "pc_type": "fieldsplit",
        "pc_fieldsplit_type": "multiplicative",

        # Velocity block: assemble the operator and factor it directly.
        "fieldsplit_0_ksp_type": "preonly",
        "fieldsplit_0_pc_type": "python",
        "fieldsplit_0_pc_python_type": "firedrake.AssembledPC",
        "fieldsplit_0_assembled_pc_type": "lu",
        "fieldsplit_0_assembled_pc_factor_mat_solver_package": "mumps",

        # Pressure block: approximate the Schur complement by an inverse
        # pressure mass matrix.
        "fieldsplit_1_ksp_type": "preonly",
        "fieldsplit_1_pc_type": "python",
        "fieldsplit_1_pc_python_type": "firedrake.MassInvPC",
        "fieldsplit_1_Mp_ksp_type": "preonly",
        "fieldsplit_1_Mp_pc_type": "lu",
        "fieldsplit_1_Mp_pc_factor_mat_solver_package": "mumps"
    }

    # Switch sign of pressure mass matrix (mu scales the MassInvPC operator).
    mu = Constant(-1.0)
    appctx = {"mu": mu, "pressure_space": 1}

    UP.assign(0)
    solve(a == L, UP, bcs=bcs, nullspace=nullspace,
          solver_parameters=solver_parameters, appctx=appctx)

    u, p = UP.split()
    u.rename("Velocity")
    p.rename("Pressure")

    # NOTE: the same file name is reused on every refinement, so presumably
    # only the finest mesh's solution remains on disk.
    File("stokes.pvd").write(u, p)

    u_error = u - u_exact
    p_error = p - p_exact
    # abs() guards sqrt against tiny negative round-off in the assembled values.
    err_u.append(sqrt(abs(assemble(dot(u_error, u_error)*dx))))
    err_p.append(sqrt(abs(assemble(p_error*p_error*dx))))
    err_div.append(sqrt(abs(assemble(div(u)**2*dx))))

# Observed convergence rates between consecutive refinements (the cell count
# per side doubles each time, hence log base 2). The divergence error is
# expected to be near zero, so the SAME small offset is added to both terms
# of its ratio: it avoids log(0) / division by zero without biasing the rate
# (the previous 1e-16 vs 1e-11 offsets skewed it heavily).
div_eps = 1.e-16

for i in range(1, len(err_u)):
    rate_u = log(err_u[i-1]/err_u[i])/log(2)
    rate_p = log(err_p[i-1]/err_p[i])/log(2)
    rate_d = log((err_div[i-1] + div_eps)/(err_div[i] + div_eps))/log(2)
    if COMM_WORLD.rank == 0:
        print("%2.2e %2.2f %2.2e %2.2f %2.2e %2.2f" %
              (err_u[i], rate_u, err_p[i], rate_p, err_div[i], rate_d))
