from firedrake import *

N = 20  # mesh resolution, passed as both nx and ny to UnitSquareMesh

mesh = UnitSquareMesh(N, N)

# Continuous piecewise-quadratic Lagrange space used for every field.
V = FunctionSpace(mesh, "CG", 2)

# State at t = 0: eta0 is interpolated from a cos*cos pattern; phi0 keeps the
# default value of a freshly created Function (zero in Firedrake).
eta0 = Function(V).interpolate(Expression("cos(2*pi*x[0])*cos(2*pi*x[1])"))
phi0 = Function(V)

# Solution Functions overwritten on every time step (names suggest the new
# elevation, the new potential, and a Laplacian-type field of the new
# potential).
eta1 = Function(V)
phi1 = Function(V)
Lphi1 = Function(V)

T = 0.1  # final time

dt = 0.005  # time step

t = 0.0  # current time

# Test function shared by all three weak forms below.
gamma = TestFunction(V)

# Coefficient of the higher-order (dispersive) terms in the forms below.
mu = 0.1

# One time step consists of three solves. Each weak form is written directly
# in terms of the Function it solves for; if the unknown does not appear in
# the residual, its Jacobian is empty and the solver cannot proceed.
#   1. phi1  from phi0 and eta0,
#   2. Lphi1 from phi1,
#   3. eta1  from eta0, phi1 and Lphi1.
# phi is advanced with the old eta and eta with the new phi (a
# symplectic-Euler-style staggering). No boundary conditions are passed to
# the problems, so the integrations by parts carry natural (zero normal flux)
# conditions. In strong form, with Lphi1 = -Laplace(phi1):
#   (1 - mu/2 Laplace)(phi1 - phi0)/dt = -eta0
#   (1 - mu/2 Laplace)(eta1 - eta0)/dt = -Laplace(phi1) - (2/3) mu Laplace(Lphi1)
# For a Fourier mode of wavenumber k this gives
#   omega^2 = k^2 (1 + 2 mu k^2/3) / (1 + mu k^2/2)^2 ~ k^2 (1 - mu k^2/3),
# the standard long-wave dispersion relation (presumably the linearised
# Benney-Luke equations -- confirm against the intended model).

# Step 1: potential update. The eta0 source must enter with a minus sign
# (phi_t = -eta). Combined with eta_t = -Laplace(phi) from Feta, a plus sign
# gives eta_tt = -Laplace(eta), whose Fourier modes grow exponentially.
Fphi = (gamma*phi1 + 0.5*mu*inner(grad(gamma), grad(phi1))
        - (gamma*phi0 + 0.5*mu*inner(grad(gamma), grad(phi0))
           - dt*gamma*eta0))*dx

# All three problems are linear in their unknown, so Newton should converge
# in a single step; solver_parameters={'snes_type': 'ksponly'} could be
# passed to skip the extra residual evaluation.
phi_problem = NonlinearVariationalProblem(Fphi, phi1)
phi_solver = NonlinearVariationalSolver(phi_problem)

# Step 2: weak Laplacian, (Lphi1, gamma) = (grad phi1, grad gamma), which
# integrates by parts to Lphi1 = -Laplace(phi1).
FLaplace = (gamma*Lphi1 - inner(grad(gamma), grad(phi1)))*dx

Laplace_problem = NonlinearVariationalProblem(FLaplace, Lphi1)
Laplace_solver = NonlinearVariationalSolver(Laplace_problem)

# Step 3: elevation update using the freshly computed phi1 and Lphi1.
Feta = (gamma*eta1 + 0.5*mu*inner(grad(gamma), grad(eta1))
        - (gamma*eta0 + 0.5*mu*inner(grad(gamma), grad(eta0)))
        - dt*inner(grad(gamma), grad(phi1))
        - 2.0/3.0*mu*dt*inner(grad(gamma), grad(Lphi1)))*dx

eta_problem = NonlinearVariationalProblem(Feta, eta1)
eta_solver = NonlinearVariationalSolver(eta_problem)

# .pvd output: the initial state, then one snapshot per time step.
phi_file = File('phi.pvd')
eta_file = File('eta.pvd')

phi_file << phi0
eta_file << eta0

# Diagnostic: current time and the integral of phi0**2. Formatted into one
# string so the output is identical under Python 2 and Python 3.
print("%s %s" % (t, assemble(phi0*phi0*dx)))

# The 0.5*dt margin absorbs round-off in the accumulated t, so the loop stops
# at t ~= T instead of taking one extra step.
while t < T - 0.5*dt:
    t += dt

    # Order matters: the Laplace solve reads the new phi1, and the eta
    # solve reads both phi1 and Lphi1.
    phi_solver.solve()
    Laplace_solver.solve()
    eta_solver.solve()

    # The new state becomes the old state for the next step.
    eta0.assign(eta1)
    phi0.assign(phi1)

    phi_file << phi0
    eta_file << eta0
    # Reported after the update so the final state at t ~= T is included.
    print("%s %s" % (t, assemble(phi0*phi0*dx)))
    

