#=============================================;
#  Run the file as:                           ;
#    python H_conv_barus.py <EQN> <FEM> <nx>  ;
#                                             ;
#    EQN = Darcy/Barus                        ;
#    FEM = RT/BDM/VMS                         ;
#    nx = No. of divisions per direction      ;
#=============================================;
from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import sys
# Parse and validate the command-line arguments (usage is in the header above).
# Fail early with a readable message instead of an IndexError / NameError later.
if len(sys.argv) != 4:
  sys.exit("usage: python %s <EQN> <FEM> <nx>  (EQN = Darcy/Barus, "
           "FEM = RT/BDM/VMS, nx = positive integer)" % sys.argv[0])
EQN = sys.argv[1]
FEM = sys.argv[2]
try:
  nx = int(sys.argv[3])
except ValueError:
  sys.exit("nx must be an integer, got %r" % sys.argv[3])
# EQN selects the body force and barus_coeff below; any other value would
# leave them undefined.
if EQN not in ('Darcy', 'Barus'):
  sys.exit("EQN must be 'Darcy' or 'Barus', got %r" % EQN)
if nx < 1:
  sys.exit("nx must be a positive integer, got %d" % nx)

#===============;
#  Create mesh  ;
#===============;
# nx x nx divisions of the unit square, triangulated (quadrilateral=False).
mesh = UnitSquareMesh(nx, nx, quadrilateral=False)

#====================================================;
#  Define function spaces and mixed (product) space  ;
#====================================================;
if FEM != 'VMS':
  # RT/BDM: FEM is passed straight through as the velocity element family,
  # paired with a piecewise-constant (DG0) pressure.
  velSpace = FunctionSpace(mesh, FEM, 1)
  pSpace = FunctionSpace(mesh, "DG", 0)
else:
  # VMS: equal-order continuous P1 velocity (vector) and P1 pressure.
  velSpace = VectorFunctionSpace(mesh, "CG", 1)
  pSpace = FunctionSpace(mesh, "CG", 1)
# Mixed velocity-pressure space.
W = velSpace * pSpace

#===================================;
#  Define trial and test functions  ;
#===================================;
# Velocity/pressure trial functions and matching test functions on W.
v, p = TrialFunctions(W)
w, q = TestFunctions(W)
# Function on the mixed space W.
solution = Function(W)
# Mesh coordinates, used to build the analytic expressions (e.g. the BCs).
x, y = SpatialCoordinate(mesh)

#=====================;
#  Define body force  ;
#=====================;
# Manufactured forcing rhob = (1 + beta*p)*u + grad(p) (mu_0 = K = 1) for the
# exact fields used at the end of the file:
#   u = (sin(pi x) cos(pi y), -cos(pi x) sin(pi y)),  p = sin(pi x) sin(pi y).
# Written with the SpatialCoordinate components x, y (as the BCs are) rather
# than the legacy Expression class.
if EQN == 'Darcy':
  rhob = project(as_vector([
      sin(pi*x)*cos(pi*y) + pi*cos(pi*x)*sin(pi*y),
      -cos(pi*x)*sin(pi*y) + pi*sin(pi*x)*cos(pi*y)]), velSpace)
  barus_coeff = 0.0
elif EQN == 'Barus':
  rhob = project(as_vector([
      sin(pi*x)*cos(pi*y) + sin(pi*y)*(sin(pi*x)*sin(pi*x)*cos(pi*y) + pi*cos(pi*x)),
      -cos(pi*x)*sin(pi*y) - sin(pi*x)*(sin(pi*y)*cos(pi*x)*sin(pi*y) - pi*cos(pi*y))]), velSpace)
  barus_coeff = 1.0
else:
  # Without this, rhob/barus_coeff would be undefined and fail later.
  raise ValueError("EQN must be 'Darcy' or 'Barus', got %r" % EQN)

#============================;
#  Define medium properties  ;
#============================;
# Unit medium. K and Kinv are separate constants (f uses Kinv, finv uses K),
# so if either is changed the other must be updated by hand to match.
K, Kinv, mu_0 = Constant(1.0), Constant(1.0), Constant(1.0)
# Linear pressure coefficient of the viscosity: 0.0 for Darcy, 1.0 for Barus.
beta = Constant(barus_coeff)

#====================;
#  Drag coefficient  ;
#====================;
def _viscosity(p):
  """Return the pressure-dependent viscosity mu_0*(1 + beta*p)."""
  return mu_0*(1+beta*p)

def f(p):
  """Return the drag coefficient: viscosity at pressure p times Kinv."""
  return _viscosity(p)*Kinv

def finv(p):
  """Return the inverse drag coefficient: K divided by the viscosity at p."""
  return (1.0/_viscosity(p))*K

#======================;
#  Boundary conditions ;
#======================;
# Essential velocity conditions on boundary ids 1 and 3 only (presumably the
# x == 0 and y == 0 sides of UnitSquareMesh -- confirm against Firedrake docs).
if FEM != 'VMS':
  # RT/BDM branch: impose components of the exact velocity field.
  bc1 = DirichletBC(W.sub(0), as_vector([sin(pi*x)*cos(pi*y), 0.0]), (1,))
  bc2 = DirichletBC(W.sub(0), as_vector([0.0, -sin(pi*y)*cos(pi*x)]), (3,))
else:
  # VMS branch: fix the x-velocity on id 1 and the y-velocity on id 3 to zero.
  bc1 = DirichletBC(W.sub(0).sub(0), Constant(0.0), (1,))
  bc2 = DirichletBC(W.sub(0).sub(1), Constant(0.0), (3,))
bcs = [bc1, bc2]

#===========================;
#  Define variational form  ;
#===========================;
# Symbolic (UFL) components of the unknown. Building the residual from these
# makes F depend on `solution` itself, so derivative(F, solution) yields the
# full Newton Jacobian, including the pressure dependence of the drag f(p).
# (The trial/test functions and `solution` defined earlier are reused here.)
u_h, p_h = split(solution)
# Subfunctions of `solution`, used for writing output and measuring errors.
vel, pres = solution.split()

# Residual: mixed weak form of f(p)*u + grad(p) = rhob, div(u) = 0
# (the pressure-gradient term appears as -div(w)*p after integration by parts).
F = dot(w, f(p_h)*u_h)*dx - div(w)*p_h*dx - q*div(u_h)*dx - dot(w, rhob)*dx
if FEM == 'VMS':
  # Stabilisation: momentum residual f(p)*u + grad(p) - rhob tested against
  # finv(p)*(f(p)*w + grad(q)), with coefficient -1/2.
  F = F - 0.5*dot(f(p_h)*w + grad(q), finv(p_h)*(f(p_h)*u_h + grad(p_h) - rhob))*dx

# Jacobian: Gateaux derivative with respect to the whole mixed solution
# (equivalent to the sum of the velocity and pressure partial derivatives).
J = derivative(F, solution)

#=================;
#  Solve problem  ;
#=================;
# Newton (SNES) with GMRES on each linearisation, preconditioned by a
# Schur-complement field split of the velocity (0) / pressure (1) blocks.
selfp_parameters = dict(
  ksp_type='gmres',
  pc_type='fieldsplit',
  pc_fieldsplit_type='schur',
  pc_fieldsplit_schur_fact_type='upper',
  # 'selfp': PETSc assembles the Schur approximation from diag of block 00.
  pc_fieldsplit_schur_precondition='selfp',
  # Velocity block: a single block-Jacobi/ILU application.
  fieldsplit_0_ksp_type='preonly',
  fieldsplit_0_pc_type='bjacobi',
  fieldsplit_0_sub_pc_type='ilu',
  # Pressure (Schur) block: a single hypre application.
  fieldsplit_1_ksp_type='preonly',
  fieldsplit_1_pc_type='hypre',
  snes_max_it=200,
)
solve(F == 0, solution, bcs=bcs, J=J,
      options_prefix="hconv_", solver_parameters=selfp_parameters)

#===============;
#  Output file  ;
#===============;
# e.g. "Barus_RT_solution.pvd", holding the computed velocity and pressure.
outfile = File('%s_%s_solution.pvd' % (EQN, FEM))
outfile.write(vel, pres)

#=================;
#  h-convergence  ;
#=================;
# Exact solution as UFL expressions in the mesh coordinates x, y.
p_exact_expr = sin(pi*x)*sin(pi*y)
u_exact_expr = as_vector([sin(pi*x)*cos(pi*y), -cos(pi*x)*sin(pi*y)])

# Projections onto the discrete spaces are for visualisation only.
p_exact = project(p_exact_expr, pSpace)
u_exact = project(u_exact_expr, velSpace)
outfile_exact = File(EQN + '_' + FEM + '_exact.pvd')
outfile_exact.write(u_exact, p_exact)

# Measure the error against the exact fields, not their projections: with the
# projection onto the same (e.g. DG0) space one would measure ||Pi u - u_h||,
# which can superconverge and misreport the convergence rate.
err_v = u_exact_expr - vel
err_p = p_exact_expr - pres
L2_v = np.sqrt(assemble(inner(err_v, err_v)*dx))
L2_p = np.sqrt(assemble(inner(err_p, err_p)*dx))
# Single-argument print(...) works under both Python 2 and Python 3.
print("L2 error in pres = %1.3e" % L2_p)
print("L2 error in vel = %1.3e" % L2_v)
