from firedrake import *
from math import log

# L2-norm quantities collected on each mesh, in the order of N:
# velocity error, pressure error, and the L2 norm of div(u_h)
# (Taylor-Hood velocities are only weakly divergence free, so the last one
# is not expected to be zero).
err_u = []
err_p = []
err_div = []

# Mesh resolutions (cells per side). Each entry doubles the previous one,
# which the log(2)-based rate computation relies on.
N = [8, 16, 32, 64, 128]

for nn in N:

    mesh = UnitSquareMesh(nn, nn)

    # Taylor-Hood pair: continuous P2 velocity, continuous P1 pressure.
    P2 = VectorFunctionSpace(mesh, "CG", 2)
    P1 = FunctionSpace(mesh, "CG", 1)
    MIXED_FS = P2 * P1

    # Manufactured solution of  -laplace(u) + grad(p) = f,  div(u) = 0.
    # The velocity below is divergence free and the pressure integrates to
    # zero over the unit square, consistent with removing the constant
    # pressure mode via the nullspace further down.
    x, y = SpatialCoordinate(mesh)
    uxexpr = sin(pi*x)*sin(pi*y)
    uyexpr = cos(pi*x)*cos(pi*y)
    ppexpr = sin(pi*x)*cos(pi*y)
    srcxexpr = 2.0*pi*pi*sin(pi*x)*sin(pi*y) + pi*cos(pi*x)*cos(pi*y)
    srcyexpr = 2.0*pi*pi*cos(pi*x)*cos(pi*y) - pi*sin(pi*x)*sin(pi*y)

    # Keep the exact solution and the source as UFL expressions. Projecting
    # them onto P2/P1 first would add a projection error to the data and make
    # the reported "error" a distance to that projection rather than to the
    # true solution, which can distort the observed convergence rates.
    u_exact = as_vector([uxexpr, uyexpr])
    p_exact = ppexpr
    source = as_vector([srcxexpr, srcyexpr])

    (u, p) = TrialFunctions(MIXED_FS)
    (v, q) = TestFunctions(MIXED_FS)

    # Symmetric saddle-point form of the Stokes system.
    a = inner(grad(u), grad(v))*dx - p*div(v)*dx - div(u)*q*dx

    L = inner(source, v) * dx

    # Velocity Dirichlet data on boundary markers 1-4 (the four sides of the
    # unit square): the nodal P2 interpolant of the exact velocity.
    bcfunc = Function(P2).interpolate(u_exact)
    bcs = [DirichletBC(MIXED_FS.sub(0), bcfunc, (1, 2, 3, 4))]

    UP = Function(MIXED_FS)

    # With Dirichlet conditions on the whole velocity boundary the pressure is
    # only determined up to a constant; remove that mode from the pressure
    # component (the velocity component has no nullspace).
    nullspace = MixedVectorSpaceBasis(
        MIXED_FS, [MIXED_FS.sub(0), VectorSpaceBasis(constant=True)])

    # Matrix-free FGMRES preconditioned by a multiplicative (block
    # lower-triangular) field split: LU (MUMPS) on the assembled velocity
    # block, and an inverse pressure mass matrix in place of the Schur
    # complement.
    # NOTE(review): PETSc >= 3.9 renamed the "*_pc_factor_mat_solver_package"
    # option to "*_pc_factor_mat_solver_type"; update these keys if MUMPS is
    # not being selected on the installed PETSc.
    parameters = {
        "mat_type": "matfree",

        "ksp_type": "fgmres",
        "ksp_max_it": "500",
        "ksp_atol": "1.e-16",
        "ksp_rtol": "1.e-10",
        "ksp_monitor_true_residual": True,
        "ksp_view": False,
        "pc_type": "fieldsplit",
        "pc_fieldsplit_type": "multiplicative",

        "fieldsplit_0_ksp_type": "preonly",
        "fieldsplit_0_pc_type": "python",
        "fieldsplit_0_pc_python_type": "firedrake.AssembledPC",
        "fieldsplit_0_assembled_pc_type": "lu",
        "fieldsplit_0_assembled_pc_factor_mat_solver_package": "mumps",

        "fieldsplit_1_ksp_type": "preonly",
        "fieldsplit_1_pc_type": "python",
        "fieldsplit_1_pc_python_type": "firedrake.MassInvPC",
        "fieldsplit_1_Mp_ksp_type": "preonly",
        "fieldsplit_1_Mp_pc_type": "lu",
        "fieldsplit_1_Mp_pc_factor_mat_solver_package": "mumps"
    }

    # Switch sign of pressure mass matrix: MassInvPC picks "mu" up from the
    # appctx, and -1 makes its mass matrix match the sign of this system's
    # (negative semi-definite) Schur complement.
    mu = Constant(-1.0)
    appctx = {"mu": mu, "pressure_space": 1}

    UP.assign(0)
    solve(a == L, UP, bcs=bcs, nullspace=nullspace, solver_parameters=parameters,
          appctx=appctx)

    u, p = UP.split()
    u.rename("Velocity")
    p.rename("Pressure")
    # The same file name is reused on every mesh, so each iteration's output
    # presumably replaces the previous one (only the finest mesh is kept).
    File("stokes.pvd").write(u, p)

    # Errors against the exact (non-projected) solution.
    u_error = u - u_exact
    p_error = p - p_exact
    err_u.append(sqrt(abs(assemble(dot(u_error, u_error)*dx))))
    err_p.append(sqrt(abs(assemble(p_error*p_error*dx))))
    err_div.append(sqrt(assemble(div(u)**2*dx)))

# Observed convergence order between consecutive meshes. The mesh size halves
# from one entry of N to the next, so rate = log2(err_coarse / err_fine).
# The divergence error may be (near) zero, so a tiny offset guards against
# division by zero / log(0). The SAME offset must be added to both errors;
# using different offsets biases the ratio when the errors are small.
div_eps = 1.e-16
for i in range(1, len(err_u)):
    rate_u = log(err_u[i-1]/err_u[i])/log(2)
    rate_p = log(err_p[i-1]/err_p[i])/log(2)
    rate_d = log((err_div[i-1] + div_eps)/(err_div[i] + div_eps))/log(2)
    # Print only once when running in parallel.
    if COMM_WORLD.rank == 0:
        print("%2.2e %2.2f %2.2e %2.2f %2.2e %2.2f" %
              (err_u[i], rate_u, err_p[i], rate_p, err_div[i], rate_d))
