# Script solving the heat equation
#   u_t = nu*div(grad(u))    (nu times the Laplacian of u)
# using implicit Euler time stepping.
#
# Variant1 solves the weak formulation
# (u_np1, v) + dt*nu*inner( grad(u_np1), grad(v)) = (u_n0, v)
# with the term from the previous time step as a right hand side.
#
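# Variant2 instead (following the nonlinear Burgers' equation example on the
# website) solves the residual form
# ( (u_np1 - u_n0)/dt, v ) + nu*inner( grad(u_np1), grad(v) ) = 0
# Such a residual must be written in terms of the unknown Function, not a
# TrialFunction; otherwise the form is bilinear and Firedrake raises
# >> ValueError: Provided residual is not a linear form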

from firedrake import *

# Set to False to select variant 2.
variant1 = True

nu = 0.1  # diffusivity in u_t = nu*div(grad(u))

nx       = 40
t        = 0.0
tend     = 0.5
nsteps   = 40
timestep = tend/float(nsteps)

mesh = UnitSquareMesh(nx, nx)

V     = FunctionSpace(mesh, "CG", 2)
u     = TrialFunction(V)
v     = TestFunction(V)

temp = Function(V, name="Temperature")      # solution at the new time level (u_np1)
u_   = Function(V, name="TemperatureOld")   # solution at the previous time level (u_n0)
u_ex = Function(V, name="TemperatureExact")

if variant1:
    # Variant 1: linear problem a(u, v) == L(v). The bilinear form uses the
    # TrialFunction u. L refers to u_ symbolically, so it only needs to be
    # built once: each solve() reassembles it with the current values of u_.
    a = (inner(u, v) + timestep * nu * inner(grad(u), grad(v))) * dx
    L = inner(u_, v) * dx
else:
    # Variant 2: residual form F(temp; v) == 0. A residual passed to
    # solve(F == 0, ...) must be linear in the test function only, so it is
    # written in terms of the unknown Function temp, not the TrialFunction u.
    # Using u would make the form bilinear, and Firedrake rejects it with
    # "ValueError: Provided residual is not a linear form".
    F = (inner((temp - u_)/timestep, v) + nu*inner(grad(temp), grad(v))) * dx

# Assign initial values. No boundary conditions are passed to solve(), so the
# natural zero-flux (homogeneous Neumann) condition applies. This cosine
# initial condition has zero normal derivative on the unit square's boundary.
ic = project(Expression("cos(pi*x[0])*cos(pi*x[1])"), V)
temp.assign(ic)   # also the initial guess for variant 2's nonlinear solver
u_.assign(ic)

# Loop over a fixed step count rather than `while t < tend`. Accumulating t by
# repeated floating-point addition can land just below tend and take one step
# too many.
for step in range(nsteps):
    if variant1:
        solve(a == L, temp)
    else:
        solve(F == 0, temp)

    u_.assign(temp)
    t = (step + 1)*timestep   # recomputed from the step count, not accumulated

# Compute the L2 error against the exact solution
#   u(x, y, t) = exp(-2*pi^2*nu*t)*cos(pi*x)*cos(pi*y).
# The Expression's parameter t receives the scaled time nu*t of the final step.
u_ex.interpolate(Expression("exp(-2.0*pi*pi*t)*cos(x[0]*pi)*cos(x[1]*pi)", t=nu*t))
print(sqrt(assemble(dot(temp - u_ex, temp - u_ex) * dx)))
