#==========================================================
#
#  Steady-state diffusion equation with Dirichlet conditions
#  Non-negative formulation as mixed complementarity problem (MCP)
#
#  Run as:
#    python MCP_diffusion.py <xseed> <yseed> <opt>
#
#  <xseed> = number of elements in x direction
#  <yseed> = number of elements in y direction
#  <opt> = 0: standard linear (Galerkin) solver
#          1: TRON bound-constrained optimization solver
#          any other value: SSILS complementarity solver
#
#==========================================================

from firedrake import *
from firedrake.petsc import PETSc
from ufl import log
# Only report UFL messages at CRITICAL level.
log.set_level(log.CRITICAL)
parameters["assembly_cache"]["enabled"] = True
import numpy as np
import sys
op2.init(log_level='ERROR')

# Command-line arguments: <xseed> <yseed> <opt>.  Exit with a usage message
# instead of an IndexError traceback when any of them is missing.
if len(sys.argv) < 4:
  sys.exit("usage: python %s <xseed> <yseed> <opt>" % sys.argv[0])
xseed = int(sys.argv[1])  # number of elements in x direction
yseed = int(sys.argv[2])  # number of elements in y direction
opt = int(sys.argv[3])    # 0: Galerkin, 1: TRON, any other value: SSILS

#=================
#  Create mesh
#=================
# Unit square split into xseed x yseed quadrilateral cells.
mesh = UnitSquareMesh(xseed, yseed, quadrilateral=True)
# Continuous degree-1 scalar space for the unknown, and the matching
# tensor-valued space that holds the diffusivity.
V, Q = FunctionSpace(mesh, 'CG', 1), TensorFunctionSpace(mesh, 'CG', 1)
u, v = TrialFunction(V), TestFunction(V)

# Discrete solution field; this is what gets written to the .pvd output.
sol = Function(V, name="solution")

#=====================
#  Medium properties
#=====================
# Volumetric source and reaction coefficient (both zero: pure diffusion).
f = Constant(0.0)
alpha = Constant(0.0)
# Anisotropic diffusivity tensor
#   D = [[c^2 d1 + s^2 d2,  -c s (d1 - d2)],
#        [-c s (d1 - d2),    s^2 d1 + c^2 d2]],  c = cos(theta), s = sin(theta)
# which has eigenvalue d1 along (c, -s) and eigenvalue d2 along (s, c).
# The ratio d1/d2 = 1e4 makes the problem strongly anisotropic (presumably
# chosen so the plain Galerkin solution violates non-negativity -- confirm).
d1 = 1.0
d2 = 0.0001
theta = pi / 6.0
co = cos(theta)
si = sin(theta)
# Interpolate the constant tensor straight into Q; no separate zero
# Function(Q) needs to be allocated first.
D = interpolate(Expression(("co*co*d1+si*si*d2", "-co*si*(d1-d2)",
                          "-co*si*(d1-d2)", "si*si*d1+co*co*d2"),
                         co=co,
                         si=si,
                         d1=d1,
                         d2=d2), Q)

#=====================
#  Variational form
#=====================
# Bilinear form: reaction term alpha*u*v plus anisotropic diffusion
# (D grad u) . grad v, both integrated over the cells.
a = (alpha * inner(u, v) * dx
     + inner(D * grad(u), grad(v)) * dx)
# Linear form: source f tested against v.
L = v * f * dx

#=======================
#  Boundary conditions
#=======================
# Zero Dirichlet data on boundary markers 1, 2 and 4; a sin(pi*x) profile on
# marker 3 (presumably the y = 0 edge in UnitSquareMesh numbering -- confirm).
bc1 = DirichletBC(V, Constant(0.0), (1, 2, 4))
bc2 = DirichletBC(V, Expression('sin(pi*x[0])'), 3)
bcs = [bc1, bc2]

# RHS lifting for the optimization solvers: build a field carrying only the
# Dirichlet values, then apply the operator assembled WITHOUT bcs to it.
tmp = Function(V)
for bc in bcs:
  bc.apply(tmp)
rhs_opt = assemble(action(a, tmp))

#========================
#  Optimization
#========================
if opt == 0:
  # Plain Galerkin solve; no bound constraints are imposed.
  problem = LinearVariationalProblem(a, L, sol, bcs=bcs, constant_jacobian=False)
  solver = LinearVariationalSolver(problem, options_prefix="solver_")
else:
  # Operator (with bcs) and RHS.  The solve stage reassembles both in place
  # (tensor=A / tensor=b), so the callbacks below keep referring to these
  # exact objects through their kargs.
  A = assemble(a, bcs=bcs)
  b = assemble(L)
  # Box constraints 0 <= u <= 1 on every degree of freedom.
  lb = Function(V)
  ub = Function(V)
  ub.assign(1)
  taoSolver = PETSc.TAO().create(PETSc.COMM_WORLD)
  taoSolver.setOptionsPrefix("opt_")

  # Set variable bounds
  with lb.dat.vec_ro as lb_vec, ub.dat.vec_ro as ub_vec:
    taoSolver.setVariableBounds(lb_vec, ub_vec)

  # TRON solver
  def Hessian(tao, petsc_x, petsc_H, petsc_HP):
    """Hessian callback for TRON.

    The objective is quadratic, so its Hessian is the constant matrix A
    registered through setHessian; nothing has to be updated here.
    """
    pass

  def ObjGrad(tao, petsc_x, petsc_g, A=None, b=None):
    """Return 0.5 x^T A x - x^T b and store its gradient A x - b in petsc_g."""
    with b.dat.vec_ro as b_vec:
      A.M.handle.mult(petsc_x, petsc_g)
      xtHx = petsc_x.dot(petsc_g)
      xtf = petsc_x.dot(b_vec)
      petsc_g.axpy(-1.0, b_vec)
      return 0.5 * xtHx - xtf

  # Complementarity solver
  def formJac(tao, petsc_x, petsc_J, petsc_JP, J=None, b=None, a=None, bcs=None):
    """Jacobian callback for the SSILS constraint map F(x) = A x - b.

    F is affine, so its Jacobian is the constant matrix J assembled once
    below and registered through setJacobian; nothing has to be updated.
    Rebinding petsc_J/petsc_JP to newly assembled matrices would NOT
    affect the matrices TAO uses -- any update must write into them in place.
    """
    pass

  def formGrad(tao, petsc_x, petsc_g, A=None, b=None, a=None, bcs=None):
    """Constraint callback: store F(x) = A x - b in petsc_g.

    Uses the pre-assembled operator A supplied through kargs; the form
    `a` is assembled (with `bcs`) only as a fallback when A is missing.
    """
    if A is None:
      A = assemble(a, bcs=bcs)
    with b.dat.vec_ro as b_vec:
      A.M.handle.mult(petsc_x, petsc_g)
      petsc_g.axpy(-1.0, b_vec)

  # Option 1 - TRON
  if opt == 1:
    taoSolver.setObjectiveGradient(ObjGrad, kargs={'A': A, 'b': b})
    taoSolver.setHessian(Hessian, A.M.handle)
    taoSolver.setType(PETSc.TAO.Type.TRON)

  # Option 2 - SSILS (any opt other than 0 and 1)
  else:
    con = Function(V)
    J = assemble(a, bcs=bcs)
    with con.dat.vec as con_vec:
      taoSolver.setConstraints(formGrad, con_vec, kargs={'A': A, 'b': b, 'a': a, 'bcs': bcs})
    taoSolver.setJacobian(formJac, J.M.handle, kargs={'J': J, 'b': b, 'a': a, 'bcs': bcs})
    taoSolver.setType(PETSc.TAO.Type.SSILS)
  taoSolver.setFromOptions()

#========================
#  Solve
#========================
# Everything below is timed under the PETSc logging stage "Solver".
with PETSc.Log.Stage("Solver"):
  # Standard solver: LinearVariationalSolver configured in the setup section.
  if opt == 0:
    solver.solve()
  # Optimization solver (TRON for opt == 1, SSILS otherwise)
  else:
    # Reassemble into the existing tensors A and b: the TAO callbacks were
    # registered with these exact objects through kargs, so they must be
    # updated in place rather than replaced by new objects.
    A = assemble(a, bcs=bcs, tensor=A)
    b = assemble(L, tensor=b)
    # Lift the Dirichlet data: subtract the bc-free operator applied to the
    # boundary values, then impose the boundary values on the boundary rows.
    # NOTE(review): relies on `-=` modifying the Function b in place so the
    # kargs reference stays valid -- confirm for the Firedrake version used.
    b -= rhs_opt
    for bc in bcs:
      bc.apply(b)
      
    # sol is the initial guess (presumably zero, as created by Function(V))
    # and receives the TAO solution in place through its writable vector.
    with sol.dat.vec as sol_vec:
      taoSolver.solve(sol_vec)
  
#=========================
#  output
#=========================
# Output file name depends on the solver: Galerkin (0), TRON (1), or SSILS
# for every other opt value.
output_names = {0: "MCP_galerkin.pvd", 1: "MCP_tron.pvd"}
outfile = File(output_names.get(opt, "MCP_compl.pvd"))
outfile.write(sol)
