#===============================================================
#
#  Diffusion example 2: ABC flow with dispersion tensor
#                       Nonhomogeneous Dirichlet BC on the top face
#                       Homogeneous Dirichlet BC everywhere else
#                       Solution sol must satisfy 0.0 <= sol <= 1.0
#
#  This one does not ... (sentence truncated in the original; the
#  intended remark is unknown -- TODO(review): complete or remove)
#
#  Run as:
#    python SS_ex2.py <seed> <guess>
#
#  <seed>  = number of elements in each of the x/y/z directions
#  <guess> = initial guess: 0 for a zero initial guess, 1 to start
#            from the solution of the linear problem without bounds
#
#===============================================================
from firedrake import *
from firedrake.petsc import PETSc
from ufl import log
from coffee.logger import set_log_noperf
import numpy as np, sys, time
import os
# MPI rank of this process; used below so that only rank 0 prints timings.
rank = PETSc.COMM_WORLD.getRank()
# Quiet library logging: PyOP2 at ERROR level, COFFEE via set_log_noperf
# (presumably disables its performance logging), UFL at CRITICAL.
op2.init(log_level='ERROR')
set_log_noperf()
log.set_level(log.CRITICAL)
# Turn on Firedrake's assembly cache (presumably reuses previously
# assembled tensors for identical forms).
parameters["assembly_cache"]["enabled"] = True

# Required command-line options: <seed> <guess> (see the header).
# Fail with a usage message instead of an IndexError / silent misuse.
if len(sys.argv) < 3:
  sys.exit("Usage: python %s <seed> <guess>" % sys.argv[0])
seed = int(sys.argv[1])
guess = int(sys.argv[2])
if seed < 1:
  sys.exit("<seed> must be a positive integer, got %d" % seed)
if guess not in (0, 1):
  sys.exit("<guess> must be 0 (zero guess) or 1 (unbounded solution), got %d" % guess)

# Bounds imposed on the solution: lower <= sol <= upper
lower = 0.0
upper = 1.0

#===============#
#  Create mesh  #
#===============#
# Hexahedral mesh: a seed x seed quadrilateral mesh of the unit square,
# extruded into seed layers.
meshbase = UnitSquareMesh(seed, seed, quadrilateral=True)
mesh = ExtrudedMesh(meshbase, layers=seed)
# Lowest-order continuous Lagrange spaces: scalar V for the unknown,
# vector Q for the velocity field.
V = FunctionSpace(mesh, "CG", 1)
Q = VectorFunctionSpace(mesh, 'CG', 1)
u = TrialFunction(V)
v = TestFunction(V)
sol = Function(V, name="solution")
x_ = Function(V)  # NOTE(review): not referenced anywhere in this script

#===================#
#  ABC flow params  #
#===================#
# Amplitudes A, B, C and per-term wavenumbers (in units of pi) of the
# ABC-type velocity field interpolated into Q below.
A, B, C = 0.3, 0.65, 1
Ac, As = 5, 2
Bc, Bs = 6, 2
Cc, Cs = 3, 4
vel_ABC = interpolate(
  Expression(("A*sin(As*pi*x[2])+C*cos(Cc*pi*x[1])",
              "B*sin(Bs*pi*x[0])+A*cos(Ac*pi*x[2])",
              "C*sin(Cs*pi*x[1])+B*cos(Bc*pi*x[0])"),
             A=A, B=B, C=C,
             Ac=Ac, As=As, Bc=Bc, Bs=Bs, Cc=Cc, Cs=Cs),
  Q)
velocity = vel_ABC

#=====================#
#  Volumetric source  #
#=====================#
# No volumetric source: the solution is driven only by the boundary data.
f = Constant(0.0)

#=====================#
#  Dispersion tensor  #
#=====================#
#   D = (alpha_D + alpha_t*|v|) I + (alpha_L - alpha_t) (v outer v) / |v|
# presumably the usual molecular-diffusion plus transverse/longitudinal
# dispersivity model.  NOTE(review): D is undefined wherever |v| == 0.
aT, aL, aD = 1e-5, 1e-1, 1e-9
alpha_t = Constant(aT)
alpha_L = Constant(aL)
alpha_D = Constant(aD)
normv = sqrt(dot(velocity, velocity))
Id = Identity(mesh.geometric_dimension())
D = (alpha_D + alpha_t*normv)*Id + (alpha_L - alpha_t)*outer(velocity, velocity)/normv

#=======================#
#  Boundary conditions  #
#=======================#
# Zero Dirichlet data on the side faces (markers 1-4 of the base mesh)
# and on the bottom face ...
bc1 = DirichletBC(V, Constant(0.0), (1, 2, 3, 4))
bc2 = DirichletBC(V, Constant(0.0), "bottom")
# ... and, on the top face, a sin*sin bump supported on [0.2, 0.8]^2
# (peak value 1, zero outside that square).
bc3 = DirichletBC(V,
                  Expression("x[0] >= 0.2 && x[0] <= 0.8 && x[1] >= 0.2 && x[1] <= 0.8"
                             " ? sin(pi*(x[0]-0.2)/0.6)*sin(pi*(x[1]-0.2)/0.6)"
                             " : 0.0"),
                  "top")
bcs = [bc1, bc2, bc3]

#====================#
#  Variational form  #
#====================#
# Steady diffusion a(u, v) = L(v), integrated with a degree-(3, 3)
# quadrature rule on the extruded cells.
a = inner(D * grad(u), grad(v)) * dx(degree=(3,3))
L = v * f * dx(degree=(3,3))
# Residual in terms of the unknown sol: F(sol; v) = a(sol, v) - L(v)
F = action(a, sol) - L

#=========================#
#  Default PETSc options  #
#=========================#
# Global monitoring and iteration limits, plus convergence reporting for
# the linear pre-solve (whose options prefix is "linear_").
petsc_options = PETSc.Options()
petsc_options["-tao_monitor"] = None
petsc_options["-tao_max_it"] = "1000"
petsc_options["-snes_max_it"] = "1000"
petsc_options["-snes_converged_reason"] = None
petsc_options.prefixPush("linear_")
petsc_options["-ksp_converged_reason"] = None
petsc_options.prefixPop()

#==========================#
#  Default solver options  #
#==========================#
# Options for the linear pre-solve: CG preconditioned by hypre BoomerAMG.
# All values are strings, as in the PETSc options database.
solver_parameters = dict(
  ksp_type='cg',
  pc_type='hypre',
  pc_hypre_type='boomeramg',
  pc_hypre_boomeramg_strong_threshold='0.75',
  pc_hypre_boomeramg_agg_nl='2',
  ksp_atol='1e-50',
  ksp_rtol='1e-7',
  snes_atol='1e-8',
)

#====================================#
#  Linear problem for initial guess  #
#====================================#
# Same forms and BCs without the bounds; solved into sol when guess == 1.
Lin_problem = LinearVariationalProblem(a, L, sol, bcs=bcs)
Lin_solver = LinearVariationalSolver(
  Lin_problem,
  options_prefix="linear_",
  solver_parameters=solver_parameters,
)

#=====================#
#  Semismooth solver  #
#=====================#
# Storage reused by the TAO callbacks:
#   A -- bilinear form a assembled with the Dirichlet conditions.
#        (This rebinds the name A, used above as an ABC-flow amplitude;
#        the velocity has already been interpolated, so that is harmless.)
#   r -- residual form F assembled at the current sol.
A = assemble(a, bcs=bcs)
r = assemble(F)
taoSolver = PETSc.TAO().create(comm=PETSc.COMM_WORLD)
taoSolver.setOptionsPrefix("")

## Disabled alternative: lift the Dirichlet data into the residual by hand.
#tmp = Function(V)
#for bc in bcs:
#  bc.apply(tmp)
#rhs_opt = assemble(action(a, tmp))
#r -= rhs_opt
#for bc in bcs:
#  bc.apply(r)

# Box constraints lower <= sol <= upper, held in constant Functions on V.
lb = Function(V)
ub = Function(V)
with lb.dat.vec as lb_vec:
  with ub.dat.vec as ub_vec:
    lb_vec.set(lower)
    ub_vec.set(upper)
    taoSolver.setVariableBounds(lb_vec, ub_vec)

# Jacobian callback.  F is linear in sol, so J is simply the bilinear
# form a assembled with the Dirichlet conditions.
def formJac(tao, petsc_x, petsc_J, petsc_JP, A=None, a=None, bcs=None):
  """Re-assemble ``a`` (with ``bcs``) into the matrix ``A`` for TAO.

  ``tao``, ``petsc_x`` and ``petsc_JP`` are unused because ``a`` does not
  depend on the iterate.  ``A``'s PETSc handle is the Mat registered with
  ``taoSolver.setJacobian``, so ``petsc_J`` is expected to be that very
  Mat; the assert is a sanity check of this.
  """
  jac = assemble(a, bcs=bcs, tensor=A)
  op2_mat = jac.M
  op2_mat._force_evaluation()
  assert petsc_J == op2_mat.handle

def formGrad(tao, petsc_x, petsc_g, r=None, bcs=None, F=None, sol=None):
  """TAO constraints callback: store the residual F evaluated at petsc_x in petsc_g.

  ``F`` is a Firedrake form whose unknown is the Function ``sol``, so the
  point ``petsc_x`` requested by TAO must be loaded into ``sol`` before
  assembly.  ``sol``'s data vector is also TAO's solution vector (it is
  passed to ``taoSolver.solve``).  During a line search TAO requests F at
  trial points held in other vectors, so ``sol``'s previous contents are
  saved and restored around the assembly.  This keeps TAO's current
  iterate from being overwritten.

  Parameters
  ----------
  tao : PETSc.TAO
      The calling solver (unused; required by the callback signature).
  petsc_x : PETSc.Vec
      Point at which to evaluate the residual.
  petsc_g : PETSc.Vec
      Output vector; receives the residual.
  r : Function
      Preallocated storage passed to ``assemble`` as ``tensor``.
  bcs : list of DirichletBC, optional
      The residual is zeroed on these nodes.  The Jacobian (formJac) is
      assembled with the same bcs, and Firedrake turns those rows into
      identity rows.  With a zero residual there, the Dirichlet values
      already applied to ``sol`` stay fixed.
  F : Form
      Residual 1-form, with ``sol`` as its coefficient.
  sol : Function, optional
      Unknown of ``F``; its values are restored before returning.  If
      None, ``F`` is assembled as-is and ``petsc_x`` is ignored.
  """
  saved = None
  if sol is not None:
    with sol.dat.vec as sol_vec:
      # At the current iterate petsc_x *is* sol's vector: nothing to load.
      # (This check is only an optimization; the copy path is also correct.)
      if petsc_x != sol_vec:
        saved = sol_vec.copy()
        petsc_x.copy(sol_vec)
  try:
    r = assemble(F, tensor=r)
  finally:
    if saved is not None:
      # Entering sol.dat.vec forces any pending (lazy) computation that
      # reads sol, i.e. the assembly above, before sol is overwritten.
      with sol.dat.vec as sol_vec:
        saved.copy(sol_vec)
      saved.destroy()
  for bc in (bcs or ()):
    bc.zero(r)
  with r.dat.vec as r_vec:
    r_vec.copy(petsc_g)

# Complementarity function F(x) = a(x, v) - L(v): TAO calls formGrad to
# evaluate it into con's vector and formJac for its Jacobian, stored in A.
# A._M (rather than A.M) presumably gives the Mat handle without forcing
# another assembly.
con = Function(V)
with con.dat.vec as con_vec:
  taoSolver.setConstraints(formGrad, con_vec,
                           kargs=dict(r=r, bcs=bcs, F=F, sol=sol))
taoSolver.setJacobian(formJac, A._M.handle,
                      kargs=dict(A=A, a=a, bcs=bcs))
taoSolver.setType(PETSc.TAO.Type.SSFLS)

# Options for TAO's inner linear solver, set without a prefix (the TAO
# prefix is ""), plus programmatic tolerances.  setFromOptions() is
# called last so the options database is applied to the solver.
petsc_options["-ksp_type"] = "stcg"
petsc_options["-pc_type"] = "hypre"
petsc_options["-pc_hypre_type"] = "boomeramg"
petsc_options["-pc_hypre_boomeramg_strong_threshold"] = "0.75"
petsc_options["-pc_hypre_boomeramg_agg_nl"] = "2"
taoSolver.getKSP().setTolerances(rtol=1e-7, atol=1e-50)
taoSolver.setTolerances(gatol=1e-8)
taoSolver.setFromOptions()

#=====================#
#  Solve the problem  #
#=====================#
# guess == 1: start from the solution of the linear problem without
# bounds (written into sol); otherwise sol keeps its initial values
# (the zero initial guess described in the header).
if guess == 1:
  Lin_solver.solve()
# The timer excludes the optional linear pre-solve above.
initialTime = time.time()
# Impose the Dirichlet values on the starting iterate.
for bc in bcs:
  bc.apply(sol)
# taoSolver.solve updates sol_vec (sol's data) in place.
with sol.dat.vec as sol_vec:
  taoSolver.solve(sol_vec)
getTime = time.time() - initialTime

#=========================
#  Output
#=========================
# print(...) with a single argument behaves the same in Python 2 and 3.
if rank == 0:
  print('\tWall-clock time: %.3e seconds' % getTime)
# Write Figures/<script name>_solution.pvd.  Only the script's base name is
# used, so running it as "python some/dir/SS_ex2.py" does not embed
# "some/dir" in the output path.
file_prefix = 'Figures/' + os.path.splitext(os.path.basename(__file__))[0] + '_'
outfile = File(file_prefix + 'solution.pvd')
outfile.write(sol)

