
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np

##################################################

# Variational problem for phi equation
def solver_phi(phi1, phi0_5, phi0, eta0, mu0_5, step_b, etaR, phi, v, dt, g, solvers_print):
    """Return a LinearVariationalSolver for the phi update.

    The returned solver writes into ``phi1`` the solution of

        int v*phi dx = int (v*phi0 - dt*g*v*(eta0 - etaR) + step_b*v*mu0_5) dx

    for all test functions ``v`` (``phi`` is the trial function).

    Use within the Stormer-Verlet (SV) scheme: pass ``0.5*dt`` for ``dt``, and
      * first half-step:  pass phi0_5 in place of ``phi1``;
      * second half-step: pass phi0_5 in place of ``phi0`` and eta1 in
        place of ``eta0``.

    Note: the ``phi0_5`` parameter itself is not referenced by the forms.

    :arg solvers_print: solver parameters forwarded to the solver.
    :returns: a ``LinearVariationalSolver``; call ``.solve()`` to update phi1.
    """
    # Left-hand (mass) bilinear form.
    a_phi = (v*phi)*dx
    # Right-hand side built entirely from known (previous) fields.
    L_phi = (v*phi0 - dt*g*v*(eta0 - etaR) + step_b*v*mu0_5)*dx

    problem = LinearVariationalProblem(a_phi, L_phi, phi1)
    return LinearVariationalSolver(problem, solver_parameters=solvers_print)

# Variational problem for eta equation
def solver_eta(eta1, eta0_5, eta0, phi0_5, eta, v, dt, Hb, H0, dR_dt, solvers_print):
    """Return a LinearVariationalSolver for the eta update.

    The returned solver writes into ``eta1`` the solution of

        int v*eta dx = int (v*eta0 + dt*Hb*grad(v).grad(phi0_5)) dx
                       + int dt*H0*dR_dt*v ds(1)

    for all test functions ``v`` (``eta`` is the trial function).

    Note from the original code: for the Stormer-Verlet method, give eta1
    instead of eta0_5.  The ``eta0_5`` parameter is not referenced by the
    forms below.

    :arg solvers_print: solver parameters forwarded to the solver.
    :returns: a ``LinearVariationalSolver``; call ``.solve()`` to update eta1.
    """
    mass_form = (v*eta)*dx
    # Boundary term on ds(1); on extruded meshes use ds_v(1) instead.
    boundary_term = dt*H0*dR_dt*v*ds(1)
    source_form = (v*eta0 + dt*Hb*inner(grad(v), grad(phi0_5)))*dx + boundary_term

    return LinearVariationalSolver(
        LinearVariationalProblem(mass_form, source_form, eta1),
        solver_parameters=solvers_print,
    )

##################################################

# Linear system for the Lagrange multiplier mu (lambda)
def solver_mu(eta0, phi0, Z0, W0, mu0_5, mu, v, step_b, dt, Hb, g, rho, Mass):
    """Build a KSP solver for the Lagrange multiplier and solve once into mu0_5.

    The left-hand-side operator is applied matrix-free as

        B x = (Mb M^{-1}) A (M^{-1} Mb) x + (rho/Mass) (Q^T x) Q

    where M is the mass matrix (``mu*v*dx``), Mb the ``step_b``-weighted
    mass matrix, A the ``Hb``-weighted stiffness matrix and Q the assembled
    vector of ``step_b*v*dx``.  The right-hand side comes from ``find_rhs``
    using (eta0, phi0, Z0, W0).

    M^{-1} is formed explicitly as a dense matrix through an LU
    factorisation of M, so memory grows with the square of the number of
    degrees of freedom; this is intended for small problems.

    :arg mu0_5: Function that receives the solution.
    :arg mu: trial function; :arg v: test function.
    :returns: ``(mu_solver, mu0_5, Mb_Minv_A)`` -- the configured PETSc KSP
        (reusable with a new right-hand side), the solved multiplier, and
        the matrix Mb*M^{-1}*A needed by ``find_rhs``.
    """
    # Bilinear forms for the matrices and the linear form for Q.
    M_form = mu*v*dx
    Mb_form = step_b*mu*v*dx
    A_form = Hb*inner(grad(mu), grad(v))*dx
    Q_form = step_b*v*dx

    # References to the assembled PETSc matrices.
    M = assemble(M_form).M.handle
    Mb = assemble(Mb_form).M.handle
    A = assemble(A_form).M.handle

    # Take a real copy of Q, so it does not alias Qf's storage outside the
    # context manager that provides the Vec.
    Qf = assemble(Q_form)
    with Qf.dat.vec_ro as vec:
        Q = vec.copy()

    # Dense identity matrix: right-hand side for the solve M * Minv = I.
    eye = PETSc.Mat().createDense(M.getSizes())
    eye.setUp()
    diag = eye.createVecLeft()
    diag.set(1.0)
    eye.setDiagonal(diag, addv=PETSc.InsertMode.INSERT_VALUES)

    # Explicit dense inverse of M.  duplicate() copies only the layout (the
    # values are overwritten by matSolve).  factorLU factors M in place.
    Minv = eye.duplicate()
    r, c = M.getOrdering("natural")
    M.factorLU(r, c)
    M.matSolve(eye, Minv)

    # Left-hand-side matrix (Mb M^{-1}) A (M^{-1} Mb).
    Minv_Mb = Minv.matMult(Mb)
    Mb_Minv = Mb.matMult(Minv)
    Mb_Minv_A = Mb_Minv.matMult(A)
    Ab = Mb_Minv_A.matMult(Minv_Mb)

    ###
    # Matrix-free operator B
    class MatrixFreeB(object):
        """PETSc Python-matrix context applying y = Ab*x + (Q2^T x) Q1."""

        def __init__(self, Ab, Q1, Q2):
            self.Ab = Ab
            self.Q1 = Q1
            self.Q2 = Q2

        def mult(self, mat, x, y):
            """Compute y = Ab*x + alpha*Q1 with alpha = Q2^T x."""
            self.Ab.mult(x, y)
            alpha = self.Q2.dot(x)
            y.axpy(alpha, self.Q1)
    ###

    B = PETSc.Mat().create()
    # getSizes() returns ((m, M), (n, N)), which is exactly the ``size``
    # argument of setSizes; it must not be star-unpacked (the second item
    # would be taken as the block size).
    B.setSizes(Ab.getSizes())
    B.setType(B.Type.PYTHON)
    B.setPythonContext(MatrixFreeB(Ab, rho/Mass*Q, Q))
    B.setUp()

    # To inspect B as a dense matrix while debugging:
    #   ctx = B.getPythonContext()
    #   dense_B = ctx.Ab[:, :] + rho/Mass*np.outer(ctx.Q2[:], ctx.Q2[:])
    #   print('LHS matrix Mb_Minv_A_Minv_Mb + Q*Q^T \n------------\n' + str(dense_B))

    # For small problems no preconditioner is needed.  The PC type is set on
    # this KSP directly instead of through the global options database, so
    # other PETSc solvers are not affected.  setFromOptions runs before
    # setUp so that command-line KSP options are applied before setup.
    mu_solver = PETSc.KSP().create()
    mu_solver.setOperators(B)
    mu_solver.setFromOptions()
    mu_solver.getPC().setType(PETSc.PC.Type.NONE)
    mu_solver.setUp()

    # Right-hand side.
    rhs = find_rhs(Mb_Minv_A, dt, g, eta0, phi0, v, step_b, Z0, W0)

    # Solve the linear system, writing directly into mu0_5.
    with rhs.dat.vec_ro as b:
        with mu0_5.dat.vec as x:
            mu_solver.solve(b, x)

    return (mu_solver, mu0_5, Mb_Minv_A)

# Calculate RHS for linear system for Lagrange multiplier
def find_rhs(Mb_Minv_A, dt, g, eta0, phi0, v, step_b, Z0, W0):
    """Assemble the right-hand side of the Lagrange-multiplier system.

    The result is

        Mb_Minv_A * (0.5*dt*g*eta0 - phi0)
            + assemble((-v*step_b*eta0/dt + v*step_b*(Z0/dt + W0))*dx)

    where the first term applies the PETSc matrix ``Mb_Minv_A`` to the
    pointwise-evaluated combination of eta0 and phi0.

    :returns: a new Function on the same space as that pointwise combination.
    """
    forcing_form = (-v*step_b*eta0/dt + v*step_b*(Z0/dt + W0))*dx

    # Pointwise combination of the fields, then multiply by Mb*M^{-1}*A.
    combined = assemble(0.5*dt*g*eta0 - phi0)
    result = Function(combined.function_space())
    with combined.dat.vec_ro as src, result.dat.vec as dst:
        Mb_Minv_A.mult(src, dst)

    result += assemble(forcing_form)
    return result

##################################################

# Explicit update for the W equation (no variational solve)
def solver_W(W1, W0, mu0_5, step_b, rho, Mass):
    """Explicitly update W:  W1 <- W0 - (rho/Mass) * int(step_b*mu0_5) dx.

    Use within the Stormer-Verlet scheme:
      * first half-step:  pass W0_5 in place of ``W1``;
      * second half-step: pass W0_5 in place of ``W0``.

    :returns: None; the result is written into W1 via ``assign``.
    """
    integral = assemble(step_b*mu0_5*dx)
    W1.assign(W0 - rho/Mass*integral)

# Explicit update for the Z equation (no variational solve)
def solver_Z(Z1, Z0, W0_5, dt):
    """Explicitly update Z:  Z1 <- Z0 + dt*W0_5.

    :returns: None; the result is written into Z1 via ``assign``.
    """
    increment = dt*W0_5
    Z1.assign(Z0 + increment)

##################################################

# 2nd-order Stormer-Verlet solvers
def solvers_SV(rhs, mu0_5, mu_solver0_5, phi_solver0_5, eta_solver1, phi_solver1,
               eta0, phi0, Z0, W0, eta1, phi1, Z1, W1, W0_5, step_b, rho, Mass, dt):
    """Advance the state by one 2nd-order Stormer-Verlet step.

    Sub-steps, in this exact order (each depends on the previous ones):
      1. mu0_5  <- mu_solver0_5 applied to ``rhs``
      2. phi0_5 <- phi_solver0_5.solve()
      3. eta1   <- eta_solver1.solve()
      4. W0_5   <- solver_W from W0
      5. Z1     <- solver_Z from Z0 and W0_5
      6. phi1   <- phi_solver1.solve()
      7. W1     <- solver_W from W0_5
    The new values are then copied into phi0, eta0, W0 and Z0.

    :returns: ``(phi0, eta0, W0, Z0)`` holding the updated state.
    """
    # 1. Lagrange multiplier at the half step.
    with rhs.dat.vec_ro as b_vec:
        with mu0_5.dat.vec as x_vec:
            mu_solver0_5.solve(b_vec, x_vec)

    # 2-7. Half/full-step updates of the remaining unknowns.
    phi_solver0_5.solve()
    eta_solver1.solve()
    solver_W(W0_5, W0, mu0_5, step_b, rho, Mass)
    solver_Z(Z1, Z0, W0_5, dt)
    phi_solver1.solve()
    solver_W(W1, W0_5, mu0_5, step_b, rho, Mass)

    # Roll the new state into the "current" fields.
    for current, updated in ((phi0, phi1), (eta0, eta1), (W0, W1), (Z0, Z1)):
        current.assign(updated)

    return (phi0, eta0, W0, Z0)

##################################################

