#
# Solve linear system for Lagrange multiplier lambda (or mu=2/dt*lambda)
#
# Last update: 21.10.15
#
from firedrake import *
from parameters import *
from heavyside import *
from IC_etaR import *

# Model parameters, all supplied by parameters.parameter_values()
L, H0, g, rho, Mass, a, Lp, Nx, Np, dt, T = parameter_values()

# 1D interval mesh with Nx cells and length L, plus its coordinate field
mesh = IntervalMesh(Nx, L)
coords = mesh.coordinates

# Function spaces: continuous P1, piecewise-constant DG0, discontinuous P1
V = FunctionSpace(mesh, "CG", 1)
V_DG0 = FunctionSpace(mesh, "DG", 0)
V_DG1 = FunctionSpace(mesh, "DG", 1)

# Coefficient fields
eta0 = Function(V)
phi0 = Function(V)
mu0_5 = Function(V)
step_b = Function(V_DG0)  # Heaviside step: 0 in the water part, 1 in the buoy part
Hb = Function(V_DG1)
etaR = Function(V)
Z0 = Constant(0.0)
W0 = Constant(0.0)

# Trial functions (eta, mu) and two independent test functions (v, u) on V
eta = TrialFunction(V)
mu = TrialFunction(V)
v = TestFunction(V)
u = TestFunction(V)

# Buoy hull shape: def_Hb receives Hb and returns (Hb, H, Zbar)
Hb, H, Zbar = def_Hb(Hb, coords.dat.data, L, Lp, H0, rho, Mass, a, Nx)

# Heaviside step function marking the buoy region
step_b = def_step(step_b, coords.dat.data, Lp, 'True')
# NOTE: the block below is DISABLED. It is a bare string literal, so none of
# it executes. It would set a sluice-gate initial condition through etaR and
# copy it into eta0. If it is ever re-enabled, be aware that it rebinds H0
# (returned by parameter_values()) to 0.1 AFTER def_Hb has already been called
# with the original H0. TODO(review): confirm that mismatch is intended.
'''
# Initial condition for sluice gate
H0 = 0.1
x1 = 0.15
x2 = 0.16
h1 = 1.2*H0
etaR_expr = def_etaR(coords.dat.data, H0, H0, h1, x1, x2);
etaR.interpolate(etaR_expr)
eta0 = eta0_eq_etaR(eta,eta0,etaR,v);
'''
# Weak formulation

################################

# Forms for the multiplier system. A is a gradient-gradient (Laplacian-type)
# term weighted by step_b, i.e. active only in the buoy part; Q1 and Q2 are
# step_b-weighted linear forms that are assembled into vectors further down.
Q1_form = rho/Mass*step_b*v*dx
Q2_form = step_b*u*dx
A_form = step_b*inner(grad(mu), grad(v))*dx

# Right-hand side, written as three named contributions summed in the same
# order as the single-expression version.
eta0_term = -v*step_b*eta0/dt
flux_term = Hb*inner(grad(0.5*dt*g*eta0 - phi0), grad(v))
buoy_term = v*step_b*(Z0/dt + W0)
rhs_form = (eta0_term + flux_term + buoy_term)*dx

# First, build the operator B:
from firedrake.petsc import PETSc

class MatrixFreeB(object):
    """Python context for a matrix-free PETSc operator B = A + Q1 Q2^T.

    An instance is attached to a PETSc MATPYTHON matrix, so PETSc calls
    ``mult`` whenever it needs the action of B and B is never assembled.

    :arg A: matrix object providing ``mult(x, y)`` (a PETSc Mat here).
    :arg Q1: vector added to ``A x`` after scaling by ``Q2 . x``.
    :arg Q2: vector dotted with the input ``x``.
    """

    def __init__(self, A, Q1, Q2):
        self.A = A
        self.Q1 = Q1
        self.Q2 = Q2

    def mult(self, mat, x, y):
        """Compute y = B*x = A*x + (Q2 . x) * Q1.

        :arg mat: the shell matrix PETSc passes in (not used here).
        :arg x: input vector (left unchanged).
        :arg y: output vector, overwritten with the result.
        """
        # y = A x
        self.A.mult(x, y)
        # alpha = Q2^T x
        alpha = self.Q2.dot(x)
        # y = alpha Q1 + y
        y.axpy(alpha, self.Q1)

# Get reference to PETSc Matrix
A = assemble(A_form).M.handle
Q1f = assemble(Q1_form)
Q2f = assemble(Q2_form)

# Take PETSc Vector copies of Q1 and Q2. Because these are copies, they must be
# re-assembled and re-copied if Q1_form / Q2_form change.
# The context variable is named `vec` (not `v`) so that the TestFunction `v`
# defined earlier in this script is not overwritten.
with Q1f.dat.vec_ro as vec:
    Q1 = vec.copy()
with Q2f.dat.vec_ro as vec:
    Q2 = vec.copy()

# Shell (MATPYTHON) matrix on the same communicator as A.
B = PETSc.Mat().create(comm=A.comm)

# A.getSizes() is ((local_rows, global_rows), (local_cols, global_cols)) and
# must be passed as ONE argument; unpacking it with * would hand the column
# sizes to setSizes' `bsize` (block size) parameter.
B.setSizes(A.getSizes())
B.setType(B.Type.PYTHON)
# Indicate that B should use the MatrixFreeB class to compute mult.
B.setPythonContext(MatrixFreeB(A, Q1, Q2))
B.setUp()

# Krylov solver for B mu = rhs. B is matrix-free (MATPYTHON), so
# preconditioners that need matrix entries (e.g. the default ILU) cannot be
# built on it; for small problems no preconditioner is needed at all.
mu_solver = PETSc.KSP().create(comm=B.comm)

mu_solver.setOperators(B)

# Choose "no preconditioner" on this KSP's own PC rather than through the
# global options database. Do it, together with setFromOptions (which still
# allows command-line overrides), BEFORE setUp: setUp builds the PC, so the
# PC type has to be decided by then.
mu_solver.getPC().setType(PETSc.PC.Type.NONE)
mu_solver.setFromOptions()
mu_solver.setUp()

# Now let's solve the system.
rhs = assemble(rhs_form)

# Debug output of the assembled right-hand side; the parenthesised form
# prints the same under Python 2 and also runs under Python 3.
print(rhs.dat.data)

# Solve B mu = rhs, writing the solution straight into mu0_5's data.
with rhs.dat.vec_ro as b, mu0_5.dat.vec as x:
    mu_solver.solve(b, x)

# An unpreconditioned Krylov solve can fail to converge; a negative converged
# reason means divergence/failure. Warn rather than silently writing it out.
if mu_solver.getConvergedReason() < 0:
    print("WARNING: mu solve did not converge (KSP reason %d after %d iterations)"
          % (mu_solver.getConvergedReason(), mu_solver.getIterationNumber()))

################################

# Write data to files: save the solved multiplier mu0_5 to mu.pvd
# (a ParaView collection file) via Firedrake's File output stream.
mu_file = File("mu.pvd")
mu_file << mu0_5

