from firedrake import *
from ufl.geometry import Jacobian, JacobianDeterminant

# Set up
# Resolution and side length of the doubly periodic square domain
# [0, domain_length]^2. The cell width delta_x used by the initial condition
# is derived from these, so the grid-scale perturbation stays tied to the
# mesh if the resolution is changed.
ncells = 120
domain_length = 2
mesh = PeriodicSquareMesh(ncells, ncells, domain_length, quadrilateral=True)
dt = 0.0001  # time-step size
outfile = File("out.pvd")

# Velocity in degree-2 RTCF (an H(div) space on quadrilaterals) and depth in
# degree-1 DG; both unknowns are solved for together in the mixed space W.
V1_elt = FiniteElement("RTCF", quadrilateral, 2)
V = FunctionSpace(mesh, V1_elt)
Q = FunctionSpace(mesh, "DG", 1)
W = MixedFunctionSpace((V, Q))

# Initialise
# U0 holds the old time level. u0 and h0 are the split components of U0;
# they share its data, so interpolating into h0 sets the depth part of U0.
U0 = Function(W)
u0, h0 = U0.split()
f = Constant(100.)  # coefficient of the f*perp(u) (Coriolis) term in the forms
x = SpatialCoordinate(mesh)
a = Constant(0.1)  # width of the Gaussian envelope
delta_x = Constant(float(domain_length) / ncells)  # cell width
# Initial depth: a cos*cos pattern of wavelength 2*delta_x (a grid-scale
# oscillation) modulated by a Gaussian centred at (1, 1), the centre of the
# default 2x2 domain. The velocity part of U0 is left at zero.
h0.interpolate(cos(pi*x[1]/delta_x)*cos(pi*x[0]/delta_x)*exp(-((x[0]-1)**2 + (x[1]-1)**2)/a**2))
w, phi = TestFunctions(W)
u, h = TrialFunctions(W)
# U1 receives the new time level. Start it from the initial state and write
# that state as the first output frame.
U1 = Function(W)
U1.assign(U0)
u1, h1 = U1.split()
outfile.write(u1, h1)

# Jacobian of the reference-to-physical coordinate map, its determinant, and
# the determinant's reciprocal; they weight the mass terms in the forms below.
J = Jacobian(mesh)
detJ = JacobianDeterminant(mesh)
detJ_inv = (1./detJ)

# One Crank-Nicolson (trapezoidal) step for the linear rotating shallow-water
# system. aeqn holds the new-level terms (trial functions u, h); Leqn holds the
# old-level terms (u0, h0). Every non-mass term appears with a factor dt/2 in
# both forms, with opposite sign on the right-hand side. No gravity or
# mean-depth coefficient appears, so both are effectively 1.
# NOTE(review): the mass terms are non-standard.
#   - The velocity mass term is inner(J w, J u)/detJ rather than inner(w, u).
#     On this uniform square mesh J is presumably (cell width) * (orthogonal
#     matrix), which makes it equal to inner(w, u) up to the sign of detJ.
#     TODO confirm.
#   - The depth mass term is weighted by detJ (the cell area, up to sign),
#     which rescales the continuity equation relative to an unweighted mass
#     matrix.
# Confirm both weightings are intended before using a non-uniform mesh.
aeqn = (
    # velocity mass term, new level
    inner(dot(J, w), dot(J, u))*detJ_inv*dx
    # implicit half of the Coriolis term
    + 0.5*dt*f*inner(w, perp(u))*dx
    # depth mass term, implicit pressure gradient (weak form of grad h,
    # integrated by parts with no boundary term on the periodic mesh),
    # implicit half of the divergence term
    + phi*h*detJ*dx - 0.5*dt*h*div(w)*dx + 0.5*dt*phi*div(u)*dx
)

Leqn = (
    # velocity mass term, old level
    inner(dot(J, w), dot(J, u0))*detJ_inv*dx
    # explicit half of the Coriolis term (sign flipped onto the RHS)
    - 0.5*dt*f*inner(w, perp(u0))*dx
    # depth mass term, explicit pressure gradient, explicit divergence
    # (each sign flipped relative to aeqn)
    + phi*h0*detJ*dx + 0.5*dt*div(w)*h0*dx - 0.5*dt*phi*div(u0)*dx
)

# Advance the solution by nsteps Crank-Nicolson steps of size dt and write an
# output frame after each step. U0 is the old time level, which Leqn reads
# through the u0/h0 views. U1 receives the new time level. The default of one
# step reproduces the original single-step run.
nsteps = 1
prob = LinearVariationalProblem(aeqn, Leqn, U1)
solver = LinearVariationalSolver(prob)
for _ in range(nsteps):
    solver.solve()
    # Copy the new state into U0. u0 and h0 share U0's data, so the next
    # solve's right-hand side is evaluated at the updated time level.
    U0.assign(U1)
    u1, h1 = U1.split()
    outfile.write(u1, h1)
