from firedrake import *

## Mesh, function space and unknowns ##
mesh = UnitSquareMesh(40, 40)
# Alternative continuous space: P = FunctionSpace(mesh, 'CG', 1)
P = FunctionSpace(mesh, 'DG', 1)
u, v = TrialFunction(P), TestFunction(P)
u_ = Function(P)   # unknown of the nonlinear problem F(u_) = 0
du = Function(P)
# NOTE(review): u and du are not referenced elsewhere in this script as
# shown — candidates for removal.

## Problem data (source f, Dirichlet data u0, boundary flux g), interpolated into P ##
x, y = SpatialCoordinate(mesh)
gaussian = exp(-(pow(x - 0.5, 2) + pow(y - 0.5, 2)) / 0.02)
f = Function(P).interpolate(-100.0*gaussian)
u0 = Function(P).interpolate(x + 0.25*sin(2*pi*y))
g = Function(P).interpolate(-sin(5*x))

## Boundary conditions ##
# No strong boundary conditions: u0 enters the variational form through
# boundary penalty terms on ds(1) + ds(2). A continuous ('CG') variant would
# instead use: bcs = DirichletBC(P, u0, (1, 2))
bcs = []

## Nonlinear diffusion coefficient ##
def D(c):
    """Return the nonlinear diffusion coefficient D(c) = 100 - 1/c.

    c: the field the coefficient depends on (the solution Function in this
       script); the result is a symbolic expression built from Constants.

    NOTE: D is singular at c = 0 and negative for 0 < c < 0.01, so the
    solution must stay away from that range for the problem to be well posed.
    """
    base = Constant(100)
    unit = Constant(1.0)
    return base - unit/c

## DG terms ##
n = FacetNormal(mesh)
h = CellDiameter(mesh)
h_avg = (h('+') + h('-'))/2
alpha = 4.0  # interior-facet penalty parameter
gamma = 8.0  # penalty parameter for the weakly imposed Dirichlet data

## Variational form ##
# Symmetric interior penalty (SIPG) residual for -div(D(u) grad(u)) = f:
#   * u = u0 imposed weakly (Nitsche) on ds(1) + ds(2),
#   * the flux D(u) grad(u).n = g prescribed on ds(3) + ds(4).
# Every flux, symmetrising and penalty term carries the coefficient D(u_).
# The boundary flux term in particular must match the volume term
# dot(grad(v), D(u_)*grad(u_))*dx; without D it is inconsistent (the exact
# solution would not satisfy F = 0). The penalties are scaled by D so they
# dominate the diffusive terms (D = 90 at the initial guess u_ = 0.1), as
# SIPG coercivity requires.
#
# For a continuous ('CG') space the facet terms can be dropped, with u0
# imposed strongly through DirichletBC:
#   F = dot(grad(v), D(u_)*grad(u_))*dx - v*f*dx - g*v*(ds(3)+ds(4))
Du = D(u_)
dirichlet = ds(1) + ds(2)
neumann = ds(3) + ds(4)
F = (
    # volume diffusion term
    dot(grad(v), Du*grad(u_))*dx
    # interior facets: consistency, symmetry and penalty terms
    - dot(avg(Du*grad(v)), jump(u_, n))*dS
    - dot(jump(v, n), avg(Du*grad(u_)))*dS
    + alpha/h_avg*avg(Du)*dot(jump(v, n), jump(u_, n))*dS
    # Dirichlet boundary (Nitsche): consistency, symmetry and penalty terms
    - v*dot(Du*grad(u_), n)*dirichlet
    - dot(Du*grad(v), n)*(u_ - u0)*dirichlet
    + (gamma/h)*Du*(u_ - u0)*v*dirichlet
    # source and prescribed boundary flux
    - v*f*dx
    - g*v*neumann
)
# J = derivative(F, u_)  # explicit Jacobian, if one is ever needed

## Solver parameters ##
# Both PETSc option sets use GMRES with true-residual and SNES monitoring;
# they differ only in the preconditioner ('pc_type').

# Reported as NOT working: GMRES preconditioned with hypre.
# NOTE(review): hypre's default preconditioner is the BoomerAMG algebraic
# multigrid, which typically struggles on DG discretisations — confirm.
solver_params1 = dict(
    ksp_type='gmres',
    pc_type='hypre',
    ksp_monitor_true_residual=True,
    snes_monitor=True,
)

# Reported as working: GMRES with a direct LU factorisation as preconditioner.
solver_params2 = dict(
    ksp_type='gmres',
    pc_type='lu',
    ksp_monitor_true_residual=True,
    snes_monitor=True,
)

## Initial guess ##
# D(c) = 100 - 1/c divides by the solution, so the nonlinear iteration must
# not start from u_ = 0; start from the constant 0.1 instead.
initial_value = 1e-1
u_.assign(initial_value)

## Solve F(u_) = 0 for u_ ##
# solver_params2 (LU) is used; solver_params1 (hypre) is reported not to work.
residual_equation = F == 0
solve(residual_equation, u_, bcs=bcs, solver_parameters=solver_params2)

## Plot image ##
# Plotting is best-effort: a missing matplotlib or a plotting failure only
# emits a warning and never aborts the script after the solve.
try:
    import matplotlib.pyplot as plt
except ImportError:
    # Narrow catch: a bare except would also swallow KeyboardInterrupt.
    plt = None
    warning("Matplotlib not imported")
try:
    plot(u_)
except Exception as e:
    warning("Cannot plot figure. Error msg '%s'" % e)
if plt is not None:
    # Only show when matplotlib was imported; previously this relied on a
    # NameError for 'plt' being caught and reported as a show failure.
    try:
        plt.show()
    except Exception as e:
        warning("Cannot show figure. Error msg '%s'" % e)
