from __future__ import print_function

from firedrake import *

# Mesh: 20 x 80 cells on the unit square, then stretched in y so that the
# domain is [0, 1] x [0, 4] (cells stay 0.05 x 0.05 before triangulation).
mesh = UnitSquareMesh(20, 80)
coords = mesh.coordinates
coords.dat.data[:, 1] = 4*coords.dat.data[:, 1]

# Continuous piecewise-quadratic space shared by every field below.
V = FunctionSpace(mesh, "CG", 2)

# Initial surface elevation cos(2*pi*x)*cos(2*pi*y).  `Expression` has been
# removed from Firedrake; interpolating the equivalent UFL expression in the
# (already stretched) spatial coordinates produces the same field.
x, y = SpatialCoordinate(mesh)
eta0 = Function(V).interpolate(cos(2*pi*x)*cos(2*pi*y))
# Initial velocity potential: not set explicitly, so it keeps the value of a
# freshly created Function (presumably zero -- Firedrake's default).
phi0 = Function(V)

# Unknowns at the new time level.
eta1 = Function(V)
phi1 = Function(V)
# Auxiliary field solved for by FLaplace, which imposes
# (Lphi1, gamma) = (grad(phi1), grad(gamma)); with no boundary terms this
# makes Lphi1 approximate MINUS the Laplacian of phi1.
Lphi1 = Function(V)

# Test function used in every weak form.
gamma = TestFunction(V)

mu = 0.1        # coefficient of the mu-weighted gradient (dispersive) terms
epsilon = 0.25  # nonlinearity parameter; epsilon = 0 gives the linear problem
T = 0.1         # final time
dt = 0.005      # time step
t = 0.0         # current time

# PETSc options shared by all three solves: print the SNES, KSP and
# line-search convergence histories.  Each solver gets its own copy.
monitor_parameters = {'snes_monitor': True,
                      'ksp_monitor': True,
                      'snes_linesearch_monitor': True}

# Step 1 -- potential update: solve for phi1 from phi0 and eta0.  The
# nonlinear term only involves phi0, so the residual is linear in phi1.
# NOTE(review): by parts, -0.5*epsilon*phi0*div(gamma*grad(phi0)) equals
# 0.5*epsilon*gamma*|grad(phi0)|^2 up to boundary terms only when grad(phi0)
# is continuous; for CG2 it jumps across facets, so the two forms differ by
# interior-facet terms.  Confirm this variant is the intended discretization.
Fphi = (gamma*(phi1 - phi0)/dt
        + 0.5*mu*inner(grad(gamma), grad((phi1 - phi0)/dt))
        + gamma*eta0
        - 0.5*epsilon*phi0*div(gamma*grad(phi0)))*dx

phi_problem = NonlinearVariationalProblem(Fphi, phi1)
phi_solver = NonlinearVariationalSolver(
    phi_problem, solver_parameters=dict(monitor_parameters))

# Step 2 -- weak Laplacian: (Lphi1, gamma) = (grad(phi1), grad(gamma)).
# No boundary integral or DirichletBC appears, so the boundary condition is
# the natural zero-flux one and Lphi1 approximates -div(grad(phi1)).
FLaplace = (gamma*Lphi1
            - inner(grad(gamma), grad(phi1)))*dx

Laplace_problem = NonlinearVariationalProblem(FLaplace, Lphi1)
Laplace_solver = NonlinearVariationalSolver(
    Laplace_problem, solver_parameters=dict(monitor_parameters))

# Step 3 -- elevation update: solve for eta1 using the new phi1 and Lphi1.
# The flux coefficient (1 + epsilon*eta0) uses the old elevation, so the
# residual is linear in eta1.
Feta = (gamma*(eta1 - eta0)/dt
        + 0.5*mu*inner(grad(gamma), grad((eta1 - eta0)/dt))
        - (1 + epsilon*eta0)*inner(grad(gamma), grad(phi1))
        - 2.0/3.0*mu*inner(grad(gamma), grad(Lphi1)))*dx

eta_problem = NonlinearVariationalProblem(Feta, eta1)
eta_solver = NonlinearVariationalSolver(
    eta_problem, solver_parameters=dict(monitor_parameters))

# VTK time series, one per field, starting with the initial condition.
phi_file = File('phi.pvd')
eta_file = File('eta.pvd')

phi_file.write(phi0, time=t)
eta_file.write(eta0, time=t)

# March from t to T.  The 0.5*dt slack in the guard absorbs floating-point
# drift in the accumulated `t`, so round(T/dt) steps are taken.
while t < T - 0.5*dt:
    # Diagnostic: current time and the integral of phi0**2 before the step.
    print(t, assemble(phi0*phi0*dx))
    t += dt

    phi_solver.solve()      # phi0, eta0         -> phi1
    Laplace_solver.solve()  # phi1               -> Lphi1
    eta_solver.solve()      # eta0, phi1, Lphi1  -> eta1

    # Promote the new time level to "old" for the next step.
    eta0.assign(eta1)
    phi0.assign(phi1)

    phi_file.write(phi0, time=t)
    eta_file.write(eta0, time=t)

# The in-loop diagnostic reports the state *before* each step; report the
# final state too so the last time level is not silently omitted.
print(t, assemble(phi0*phi0*dx))
    

