from firedrake import *
from firedrake_adjoint import *
from pyadjoint.placeholder import Placeholder

# Adapted from the dolfin-adjoint MPEC (mathematical program with equilibrium
# constraints) demo; see http://www.dolfin-adjoint.org/en/latest/documentation/mpec/mpec.html



def smoothmax(r, eps=1e-4):
    """Return a C^1-smooth UFL approximation of ``max(r, 0)``.

    Piecewise definition of the returned expression:

    * ``0``               where ``r < 0``
    * ``r**2 / (2*eps)``  where ``0 <= r <= eps``
    * ``r - eps/2``       where ``r > eps``

    The pieces match in value and first derivative at ``r = 0`` and
    ``r = eps``, so ``eps`` only controls the width of the quadratic
    transition zone.
    """
    linear_part = r - eps/2
    quadratic_part = r**2 / (2*eps)
    # Inner branch: clamp to zero below 0, quadratic ramp up to eps.
    below_threshold = conditional(lt(r, 0), 0, quadratic_part)
    return conditional(gt(r, eps), linear_part, below_threshold)

# Discretisation: 128x128 unit-square mesh; state y and control u both live in
# the continuous piecewise-linear (CG1) space V.
mesh = UnitSquareMesh(128, 128)
x = SpatialCoordinate(mesh)
V = FunctionSpace(mesh, "CG", 1)

y = Function(V, name="Solution")
u = Function(V, name="Control")
w = TestFunction(V)

# Penalty parameter for the smoothed constraint term; the continuation loop
# later halves it with alpha.assign(...).
alpha = Constant(1e-2)
# Registered as a pyadjoint Placeholder so that those later alpha.assign(...)
# updates are presumably picked up when the recorded tape is replayed.
Placeholder(alpha)


# Source term f(x) = 0.25 - |x0*x1 - 0.5|, interpolated into V.
f = interpolate(-abs(x[0]*x[1] - 0.5) + 0.25, V)
# Weak residual of -div(grad y) - (1/alpha) * smoothmax(-y) = f + u.
# The penalty term is nonzero only where y < 0 (smoothmax(-y) > 0), so it
# penalises violations of y >= 0; smaller alpha means a stiffer penalty.
F = inner(grad(y), grad(w))*dx - 1 / alpha * inner(smoothmax(-y), w)*dx - inner(f + u, w)*dx
# Homogeneous Dirichlet condition on the whole boundary.
bc = DirichletBC(V, 0.0, "on_boundary")
# Nonlinear solve for y. With firedrake_adjoint imported, this is annotated on
# the tape so the reduced functional can replay it.
solve(F == 0, y, bcs=bc)

# Tracking-type objective: target state yd is a copy of the source term f,
# plus a quadratic control cost weighted by nu.
# NOTE(review): the linked dolfin-adjoint MPEC formulation appears to weight
# the control cost by nu/2 (J = 1/2||y - yd||^2 + nu/2||u||^2); here it is
# nu * ||u||^2 -- confirm which scaling is intended.
nu = 0.01
yd = f.copy(deepcopy=True)
J = assemble(0.5 * inner(y - yd, y - yd) * dx +  nu * inner(u, u) * dx)

# u is the optimisation variable.
m = Control(u) 
# Control on y taken *after* the annotated solve, i.e. on the solve's output;
# the continuation loop reads from and writes to it.
ic = Control(y)
# Jhat(u) re-evaluates J by replaying the recorded tape.
Jhat = ReducedFunctional(J, m)

# Create output files. Each continuation step appends one snapshot, so the
# series records the solution path as alpha decreases.
ypvd = File("output/y_opt.pvd")
upvd = File("output/u_opt.pvd")


# Continuation in the penalty parameter: halve alpha, then re-solve the
# bound-constrained (0.01 <= u <= 0.03) reduced problem with L-BFGS-B.
for _ in range(4):
    alpha.assign(float(alpha) / 2)
    print("Set alpha to %f." % float(alpha))

    u_opt = minimize(
        Jhat,
        method="L-BFGS-B",
        bounds=(0.01, 0.03),
        options={"gtol": 1e-12, "ftol": 1e-100},
    )

    # State y as stored on the tape after the optimisation's evaluations.
    y_opt = Control(y).tape_value()
    # NOTE(review): y_opt was just read from the same recorded variable of y
    # that `ic` wraps, so this appears to write the value back onto itself.
    # If the intent is to warm-start the Newton solve, the Control probably
    # needs to be taken on y *before* the annotated solve -- confirm.
    ic.update(y_opt)

    # L2 norm of the negative part of y, i.e. the violation of y >= 0.
    print("Feasibility: ", norm(Max(Constant(0.0), -y_opt)))
    print("Norm of y  ", norm(y_opt))
    print("Norm of u_opt : ", norm(u_opt))

    # Save this round's result immediately. This fills the .pvd series with
    # one frame per alpha and keeps earlier rounds on disk if a later
    # minimize() fails.
    ypvd.write(y_opt)
    upvd.write(u_opt)

# Release the recorded operations now that optimisation is finished.
tape = get_working_tape()
tape.clear_tape()
