from firedrake import *
import matplotlib.pyplot as plt

# Pure-Neumann Poisson problem on the unit square:
#
#     -div(grad(u)) = f   in the domain
#      grad(u) . n  = g   on the boundary
#
# With only Neumann data, u is determined up to an additive constant. The
# mixed formulation below adds a scalar Lagrange multiplier c, drawn from the
# global-constant "R" space. Its test function d enforces integral(u) dx = 0,
# which fixes that constant.

# Create mesh and define function spaces: P1 for u, R (one global scalar)
# for the multiplier c, combined into a mixed space.
mesh = UnitSquareMesh(64, 64)
V = FunctionSpace(mesh, "CG", 1)
R = FunctionSpace(mesh, "R", 0)
W = MixedFunctionSpace((V, R))

# Define variational problem.
# The trial function gets its own name so that `w` below unambiguously
# refers to the computed solution.
w_trial = TrialFunction(W)
u, c = split(w_trial)
v, d = TestFunctions(W)

# Source and Neumann data, built as UFL expressions of the spatial
# coordinates. This replaces the legacy C-string `Expression` API, which
# current Firedrake no longer provides.
x, y = SpatialCoordinate(mesh)
f = Function(V).interpolate(10 * exp(-((x - 0.5) ** 2 + (y - 0.5) ** 2) / 0.02))
g = Function(V).interpolate(-sin(5 * x))

# c*v couples the multiplier into the u-equation; u*d imposes the
# zero-mean constraint on u.
a = (inner(grad(u), grad(v)) + c * v + u * d) * dx
L = f * v * dx + g * v * ds

# Compute solution
w = Function(W)
solve(a == L, w)

# NOTE(review): newer Firedrake releases deprecate Function.split() in favour
# of Function.subfunctions; kept here for compatibility with older versions.
u, c = w.split()

# Plot solution
# NOTE(review): newer Firedrake moves plotting to firedrake.pyplot
# (e.g. tripcolor); `plot` is kept to match the API this script targets.
plot(u)
plt.show()
