from firedrake import *

# Mesh resolution: the unit square is divided into N x N cells.
N = 20

mesh = UnitSquareMesh(N, N)

# Continuous Galerkin (Lagrange) space of polynomial degree 2; every field
# below lives in this space.
V = FunctionSpace(mesh, "CG", 2)

# Initial condition for eta: cos(2*pi*x)*cos(2*pi*y).  Built from UFL
# coordinate expressions rather than the legacy C-string Expression, which
# current Firedrake no longer supports.
x = SpatialCoordinate(mesh)
eta0 = Function(V).interpolate(cos(2*pi*x[0])*cos(2*pi*x[1]))
# phi0 is left as a freshly created Function (presumably zero initial
# potential -- NOTE(review): confirm no other initial phi is intended).
phi0 = Function(V)

# Fields at the new time level, filled in by the solvers every step.
eta1 = Function(V)
phi1 = Function(V)
# Auxiliary field: L2 projection of the weak form of -Laplacian(phi1).
Lphi1 = Function(V)

# Final time of the simulation.
T = 2.0

# Time step.
dt = 0.005

# Current simulation time, advanced by dt in the time loop.
t = 0.0

# Trial functions for the unknowns and the shared test function.  Note that
# `phi` is also reused below as the trial function for the Lphi1 projection
# (all fields live in the same space V).
eta = TrialFunction(V)
phi = TrialFunction(V)
gamma = TestFunction(V)

# Coefficient of the Laplacian ("dispersive") terms -- presumably the
# dispersion parameter of the wave model; confirm against the derivation.
mu = 0.1

# Step 1: update phi using the OLD eta0.  Weak form of
#   (I - mu/2 * Laplacian)(phi1 - phi0) = -dt * eta0.
# No bcs are passed, so the integrated-by-parts Laplacian terms carry
# natural (zero normal-derivative) boundary conditions.
aphi = (gamma*phi + 0.5*mu*inner(grad(gamma),grad(phi)))*dx
Lphi = (gamma*phi0 + 0.5*mu*inner(grad(gamma),grad(phi0))
        - dt*gamma*eta0)*dx

phi_problem = LinearVariationalProblem(aphi,Lphi,phi1)
phi_solver = LinearVariationalSolver(phi_problem)

# Step 2: mass-matrix projection of the weak Laplacian of the NEW phi1.
# Since inner(grad(gamma), grad(phi1)) integrates to -(gamma, Laplacian(phi1))
# up to boundary terms, Lphi1 approximates MINUS the Laplacian of phi1,
# despite the "Laplace" name.  Must run after phi_solver (reads phi1).
aLaplace = gamma*phi*dx
LLaplace = inner(grad(gamma),grad(phi1))*dx

Laplace_problem = LinearVariationalProblem(aLaplace,LLaplace,Lphi1)
Laplace_solver = LinearVariationalSolver(Laplace_problem)

# Step 3: update eta using the NEW phi1 and Lphi1.  Weak form of
#   (I - mu/2 * Laplacian)(eta1 - eta0)
#       = dt * (-Laplacian phi1) + 2/3 * mu * dt * (-Laplacian Lphi1),
# i.e. roughly -dt*Laplacian(phi1) + 2/3*mu*dt*Laplacian^2(phi1).
# Using the old eta for phi and the new phi for eta resembles a
# symplectic-Euler-type splitting (NOTE(review): model is presumably a
# linearised Benney-Luke/Boussinesq system -- confirm).
# Must run after phi_solver and Laplace_solver.
aeta = (gamma*eta + 0.5*mu*inner(grad(gamma),grad(eta)))*dx
Leta = (gamma*eta0 + 0.5*mu*inner(grad(gamma),grad(eta0))
        + dt*inner(grad(gamma),grad(phi1))
        + 2.0/3.0*mu*dt*inner(grad(gamma),grad(Lphi1)))*dx

eta_problem = LinearVariationalProblem(aeta,Leta,eta1)
eta_solver = LinearVariationalSolver(eta_problem)

# ParaView output, one time series per field.
phi_file = File('phi.pvd')
eta_file = File('eta.pvd')

# Record the initial state, tagged with the physical time so the .pvd
# collection carries real time values instead of bare step indices.
phi_file.write(phi0, time=t)
eta_file.write(eta0, time=t)

# March until T.  The half-step margin keeps floating-point drift in the
# accumulated `t` from triggering one extra step past T.
while t < T - 0.5*dt:
    # Diagnostic: current time and the squared L2 norm of phi0.  A single
    # pre-formatted argument prints identically under Python 2 and 3
    # (the old `print t, x` statement is a SyntaxError in Python 3).
    print("{} {}".format(t, assemble(phi0*phi0*dx)))
    t += dt

    # Order matters: Laplace_solver reads phi1, and eta_solver reads both
    # phi1 and Lphi1.
    phi_solver.solve()
    Laplace_solver.solve()
    eta_solver.solve()

    # The new time level becomes the old one for the next step.
    eta0.assign(eta1)
    phi0.assign(phi1)

    phi_file.write(phi0, time=t)
    eta_file.write(eta0, time=t)
    

