# First we import all modules from Firedrake
from firedrake import *
# Matplotlib is optional: it is only used to show the stored solutions at the
# end of the script (for 1D results it is a better visualisation tool than
# ParaView).
# Catch Exception rather than using a bare `except:` so that KeyboardInterrupt
# and SystemExit are not swallowed, while any failure to import or initialise
# matplotlib still degrades to a warning instead of aborting the run.
try:
    import matplotlib.pyplot as plt
except Exception:
    warning("Matplotlib not imported")

# Initialise the mesh: m equal elements on the unit interval [0, 1].
m = 20
mesh = UnitIntervalMesh(m)

# Simulation start and end times.
t = 0.0
end = 1.0

# Function space on the mesh: continuous, piecewise-linear Lagrange ("CG", 1).
V = FunctionSpace(mesh, "CG", 1)

# Time step. The plain float Dt is used to advance the Python time variable t;
# the Constant dt carries the same value into the variational forms.
Dt = 0.1
dt = Constant(Dt)

# theta-method parameter: theta = 0.5 gives Crank-Nicolson, theta = 1.0 gives
# backward (implicit) Euler and theta = 0.0 gives forward (explicit) Euler.
# With theta = 1.0 this script runs backward Euler, not Crank-Nicolson.
theta = 1.0

# Initial condition: the hat function that rises linearly from 0 at x = 0 to
# 1 at x = 0.5 and falls back to 0 at x = 1. Throughout the time loop u_prev
# holds the previous time level u^n.
u_prev = Function(V)
u_prev.interpolate(Expression(" (x[0] < 0.5) ? 2*x[0] : 2 - 2*x[0]"))

# Snapshots of every time level for the final matplotlib plot, starting with
# a separate Function built from the initial condition.
all_u = [Function(u_prev)]

# ParaView output file; the initial condition is written first, at the
# current (initial) time t.
outfile = File("./Results/diffusion.pvd")
outfile.write(u_prev, t=t)

# Trial function u represents the unknown u^{n+1}; phi is the test function.
u = TrialFunction(V)
phi = TestFunction(V)


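# Define the variational problem a(u, phi) = L(phi) for a single time step.
# Time is discretised with the theta-method (theta = 0.5 is Crank-Nicolson,
# theta = 1 is backward Euler).
# First define a helper for the diffusion term that appears in both forms.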

def flux(u, phi):
    """Return the diffusion integrand dot(grad(u), grad(phi)).

    The same term appears in the bilinear form (applied to the trial
    function) and in the linear form (applied to the previous solution),
    so it is factored out here.
    """
    grad_u = grad(u)
    grad_phi = grad(phi)
    return dot(grad_u, grad_phi)



# theta-method weak form. The left-hand side collects the mass term and the
# theta-weighted diffusion term acting on the unknown u^{n+1}; the right-hand
# side collects the mass term and the (1 - theta)-weighted diffusion term
# acting on the known u^n (u_prev).
a = (u * phi / dt + theta * flux(u, phi)) * dx
L = (u_prev * phi / dt - (1 - theta) * flux(u_prev, phi)) * dx

# Homogeneous Dirichlet data on boundary markers 1 and 2, i.e. both ends of
# the interval mesh.
bc1, bc2 = [DirichletBC(V, Expression('0.0'), marker) for marker in (1, 2)]

# `out` receives the new time level u^{n+1} each time the solver runs.
out = Function(V)

u_problem = LinearVariationalProblem(a, L, out, bcs=[bc1, bc2])
u_solver = LinearVariationalSolver(u_problem)


# March in time with a fixed step Dt until t reaches `end`.
# Repeatedly adding Dt accumulates floating-point round-off (ten additions of
# 0.1 give 0.9999999999999999, which is still < 1.0), so a plain `t < end`
# test would run one extra step past `end`. A tolerance that is a tiny
# fraction of Dt absorbs that round-off without otherwise changing how many
# steps are taken.
while t < end - 1e-8 * Dt:
    # First advance the time.
    t += Dt

    # Report progress. print(...) with one argument behaves the same under
    # Python 2 and Python 3.
    print("Time is %g" % t)

    # Solve for u^{n+1} (stored in `out`), then make it the new u^n.
    u_solver.solve()
    u_prev.assign(out)

    # Write the new time level for ParaView.
    outfile.write(u_prev, t=t)

    # Keep a separate snapshot for the final matplotlib plot.
    all_u.append(Function(out))


# Plot all stored time levels (using the `plot` provided by
# `from firedrake import *`), then display the figure. Both steps are
# best-effort: any failure only emits a warning. str(e) is used for the
# message because Exception.message is deprecated since Python 2.6: it is
# empty for exceptions built with several arguments (e.g. IOError), and it
# does not exist at all in Python 3.
try:
    plot(all_u, axes=[0.0, 1.5, 0.0, 1.5])
except Exception as e:
    warning("Cannot plot figure. Error msg: '%s'" % str(e))

try:
    plt.show()
except Exception as e:
    warning("Cannot show figure. Error msg: '%s'" % str(e))
