# User parameters
# Mesh resolution: the mesh below is built as UnitSquareMesh(seed, seed),
# i.e. `seed` cells along each side of the unit square.
seed = 40

#========================
#  Initialize
#========================
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
from firedrake import utils

#========================
#  Create mesh 
#========================
# Continuous P1 velocity (vector) and P1 pressure (scalar) on a uniform
# triangulation of the unit square, combined into one mixed space.
mesh = UnitSquareMesh(seed, seed)
V, Q = VectorFunctionSpace(mesh, "CG", 1), FunctionSpace(mesh, "CG", 1)
W = V * Q  # mixed (velocity, pressure) space
# Split the mixed trial/test functions into velocity and pressure parts.
(v, p), (w, q) = TrialFunctions(W), TestFunctions(W)

#========================
#  Weak form
#========================
# Data g = u + grad(p) for the manufactured solution
#   u = (cos(pi x) sin(pi y), -sin(pi x) cos(pi y))   (divergence-free)
#   p = sin(2 pi x) sin(2 pi y)
# The u components match the velocity boundary data used further down.
g_components = (
    "cos(pi*x[0])*sin(pi*x[1])+2*pi*cos(2*pi*x[0])*sin(2*pi*x[1])",
    "-sin(pi*x[0])*cos(pi*x[1])+2*pi*sin(2*pi*x[0])*cos(2*pi*x[1])",
)
g = Function(V)
g.interpolate(Expression(g_components))

# Least-squares formulation: minimise ||v + grad(p) - g||^2 + ||div(v)||^2.
# Its optimality condition yields the symmetric bilinear form `a` and the
# right-hand side `L`.
a = dot(v + grad(p), w + grad(q)) * dx + div(v) * div(w) * dx
L = dot(w + grad(q), g) * dx

#========================
#  Solver parameters
#========================
# PETSc options for the linear solve: conjugate gradients (the least-squares
# operator `a` is symmetric) with a block-Jacobi preconditioner.
# To print the true residual each iteration, add ksp_monitor_true_residual=True.
cg_parameters = dict(
    ksp_type='cg',
    pc_type='bjacobi',
)
#========================
#  Boundary conditions
#========================
class PointDirichletBC(DirichletBC):
    """Dirichlet condition applied only at the mesh vertex at the origin.

    Only ``nodes`` is overridden: instead of collecting nodes from a boundary
    marker, it selects every row of the mesh coordinate array whose entries
    are all exactly zero. In this script it pins the pressure at one point,
    because the pressure enters the formulation only through grad(p).

    NOTE(review): the returned indices are positions in the mesh *coordinate*
    field. Using them as nodes of this BC's function space assumes both share
    one node numbering (as for a P1 space on a linear mesh) — confirm before
    using this with other spaces.
    """

    @utils.cached_property
    def nodes(self):
        """Indices of coordinate rows equal to the origin (computed once, then cached)."""
        coords = self.function_space().mesh().coordinates.dat.data_ro
        # A row is the origin when every one of its components is zero.
        is_origin = np.all(coords == 0, axis=1)
        return np.flatnonzero(is_origin)

# Velocity boundary data taken from the exact solution. On UnitSquareMesh,
# markers 1/2 are the x = 0 / x = 1 edges and 3/4 the y = 0 / y = 1 edges, so
# these fix the normal velocity component on every edge.
bc1 = DirichletBC(W.sub(0).sub(0),
                  Expression("cos(pi*x[0])*sin(pi*x[1])"),
                  (1, 2))
bc2 = DirichletBC(W.sub(0).sub(1),
                  Expression("-sin(pi*x[0])*cos(pi*x[1])"),
                  (3, 4))
# The pressure is only determined up to a constant (it appears only as
# grad(p)), so pin it to 0 at the origin. The sub_domain argument 0 is
# presumably unused, since PointDirichletBC overrides `nodes`.
# Alternative: drop bc3 and instead pass the constant-pressure null space
# MixedVectorSpaceBasis(W, [W[0], VectorSpaceBasis(constant=True)]) to the
# solver through its `nullspace` argument.
bc3 = PointDirichletBC(W.sub(1), 0, 0)
bc_all = [bc1, bc2, bc3]
#========================
#  Solve problem
#========================
# Assemble the operator and right-hand side with all boundary conditions,
# then solve with the CG options defined above. The "cg_" options prefix
# lets PETSc command-line options (e.g. -cg_ksp_monitor) adjust this solve.
# If bc3 is replaced by a null space, pass nullspace=... to LinearSolver.
solution = Function(W)
A = assemble(a, bcs=bc_all)
b = assemble(L, bcs=bc_all)
solver = LinearSolver(
    A,
    solver_parameters=cg_parameters,
    options_prefix="cg_",
)
solver.solve(solution, b)

#========================
#  Output solutions
#========================
# Write the discrete velocity and pressure to ParaView (.pvd) files.
# Distinct names are used so the trial functions `v` and `p` that define the
# weak form are not silently rebound to the computed solution components.
u_h, p_h = solution.split()
File("Figures/LS_pressure.pvd") << p_h
File("Figures/LS_velocity.pvd") << u_h
