
from firedrake import *

# import mpi4py
# from mpi4py import MPI
# print('commsize',MPI.COMM_WORLD.size)

# Discretisation: polynomial order of the RTCF/DG pair and the number of
# cells in each direction of the mesh.
order = 1
nx, ny = 50, 50

# Extent of the (doubly periodic) domain.
Lx, Ly = 2.0, 2.0

# Coefficients of the residual: g multiplies h in the momentum equation,
# H0 multiplies div(v) in the h equation, f multiplies perp(v).
g, H0, f = 1.0, 1.0, 25.0

# Time step: 0.5 * dxdof / sqrt(g*H0), a CFL-type choice in which dxdof is
# half the cell width per polynomial degree and sqrt(g*H0) is the wave speed.
dxdof = 0.5 * Lx / nx / order
dt = 0.5 * dxdof / sqrt(g * H0)

# Number of time steps taken by the loop at the end of the script.
nsteps = 5

# Build a doubly periodic quadrilateral mesh of the Lx x Ly domain.
mesh = PeriodicRectangleMesh(
    nx, ny, Lx, Ly, quadrilateral=True, direction='both'
)

# Compatible finite-element pair: RTCF (H(div)) of degree `order` for the
# v field and discontinuous DG of degree `order - 1` for the h field.
rtcf_element = FiniteElement("RTCF", quadrilateral, order)
dg_element = FiniteElement("DG", quadrilateral, order - 1)

hdivspace = FunctionSpace(mesh, rtcf_element)
l2space = FunctionSpace(mesh, dg_element)

# Component order matters: index 0 is h (L2), index 1 is v (H(div)); the
# test functions and splits below unpack in this order.
mixedspace = MixedFunctionSpace([l2space, hdivspace])

#set initial condition: a Gaussian bump in h centred at (x0, y0)
x, y = SpatialCoordinate(mesh)
dh = 1.    # amplitude of the bump
a = 1/10.  # width of the bump (divides the squared distance)
x0 = 1.0
y0 = 1.0
hval = dh * exp(-((x-x0)*(x-x0) + (y-y0)*(y-y0))/(a*a))
h0 = Function(l2space, name='h0')
hprojector = Projector(hval, h0)
hprojector.project()
# v0 is a freshly created Function, i.e. the initial v field is zero.
v0 = Function(hdivspace, name='v0')


def _subfunctions(func):
    """Return the component Functions of the mixed Function ``func``.

    Recent Firedrake releases expose these as ``Function.subfunctions`` and
    deprecate ``Function.split()`` (which is also easily confused with the
    UFL ``split`` used to build the forms); older releases only provide
    ``split()``, so fall back to it when the property is missing.
    """
    try:
        return func.subfunctions
    except AttributeError:
        return func.split()


#create functions and test functions
xn = Function(mixedspace, name='xn')
xnp1 = Function(mixedspace, name='xnp1')
hhat, vhat = TestFunctions(mixedspace)

#assign initial condition through the component Functions; the script relies
#on these sharing storage with xn/xnp1, so assigning them fills the mixed data
hnsplit, vnsplit = _subfunctions(xn)
hnp1split, vnp1split = _subfunctions(xnp1)
hnsplit.assign(h0)
hnp1split.assign(h0)  # xnp1 is the solve's unknown, so this is its first guess
vnsplit.assign(v0)
vnp1split.assign(v0)

#checkpoint = DumbCheckpoint('test',mode=FILE_CREATE)
#checkpoint.set_timestep(0)
#checkpoint.store(hnsplit,name='h')
#checkpoint.store(vnsplit,name='v')

#create variational problem + solver
# Crank-Nicolson / implicit-midpoint step: the time derivative is the
# difference quotient (n+1 - n)/dt, and every spatial term uses the average
# of the n and n+1 states.
hn, vn = split(xn)
hnp1, vnp1 = split(xnp1)

h_sum = hnp1 + hn          # halved inside the momentum term below
v_mid = (vnp1 + vn)/2.     # midpoint v, shared by the div and perp terms

# Time-derivative ("mass") terms.
Mh = inner(hhat, (hnp1 - hn)/dt) * dx
Mv = inner(vhat, (vnp1 - vn)/dt) * dx

# Spatial terms.
# NOTE(review): after integrating div(vhat) by parts on the periodic mesh
# these signs give h_t = H0 div(v) and v_t = g grad(h) - f perp(v), i.e. v is
# the negative of u in the usual h_t + H0 div(u) = 0,
# u_t + f perp(u) + g grad(h) = 0. Self-consistent (energy-conserving), but
# confirm this convention is intended.
Sh = -inner(hhat, H0 * div(v_mid)) * dx
Sv = (inner(div(vhat), g * h_sum/2.) * dx
      + inner(vhat, f*perp(v_mid)) * dx)

R = Mh + Mv + Sh + Sv

problem = NonlinearVariationalProblem(R, xnp1)
solver = NonlinearVariationalSolver(problem)
# R is linear in xnp1, so a single Krylov solve would suffice; options once
# tried for that are kept below for reference:
#,solver_parameters={'snes_type': 'ksponly','ksp_converged_reason': '','ksp_type': 'lgmres', 'pc_type':'jacobi'}) #options_prefix = 'linsys_',

# Advance nsteps time steps. After each solve the new state is copied into
# xn; xnp1 keeps its value, so it is also the next solve's initial guess.
# NOTE(review): under MPI every rank executes print(i); consider a
# rank-0-only print if this is run in parallel.
for i in range(1, nsteps + 1):
    print(i)
    solver.solve()
    xn.assign(xnp1)
    #checkpoint.set_timestep(i)
    #checkpoint.store(hnsplit,name='h')
    #checkpoint.store(vnsplit,name='v')

#checkpoint.close()
