from firedrake import *
import numpy as np
from scipy import stats
# Global Firedrake settings for this script: set the PyOP2 log level to
# 'WARNING' and switch the assembly cache off.
parameters['pyop2_options']['log_level'] = 'WARNING'
parameters['assembly_cache']['enabled'] = False

# Synthetic two-frame image sequence: an isotropic Gaussian hump that moves
# one pixel in the first index direction between frame 0 and frame 1.
nx = 21  # number of cells
ny = 21

# Floor division keeps the hump centres on integer pixel indices regardless
# of the Python version (plain `/` gives 10 on Python 2 but 10.5 on Python 3,
# which would put the peak between pixels).
mean0 = [nx // 2, ny // 2]  # mean of 1st Gaussian hump
mean1 = [nx // 2 + 1, ny // 2]  # mean of 2nd Gaussian hump
sigma = 0.1 * nx  # stdev
cov = [[sigma**2, 0], [0, sigma**2]]  # shared isotropic covariance
rv0 = stats.multivariate_normal(mean0, cov)
rv1 = stats.multivariate_normal(mean1, cov)

# pos[i, j] = (i, j): the integer pixel indices at which the pdfs are sampled.
x, y = np.mgrid[0:nx:1, 0:ny:1]
pos = np.dstack((x, y)).astype(float)

# Normalise by the peak of the first hump so the peak intensity is 1.
scale = 1.0 / rv0.pdf(mean0)
p0 = rv0.pdf(pos) * scale
p1 = rv1.pdf(pos) * scale

# Indexed as pic_sequence[frame, i, j].
pic_sequence = np.array([p0, p1])

# Space-time mesh: an nx-by-ny quadrilateral mesh of the rectangle
# [0, nx] x [0, ny], extruded by a single layer. The third coordinate is used
# further down as the frame index of the image sequence.
base_mesh = RectangleMesh(nx, ny, nx, ny, quadrilateral=True)
mesh = ExtrudedMesh(base_mesh, layers=1)

# Image intensity: degree-0 DG in the horizontal directions, degree-1 CG
# along the extrusion direction.
intensity_elt = TensorProductElement(
    FiniteElement("DG", "quadrilateral", 0),
    FiniteElement("CG", "interval", 1),
)
W = FunctionSpace(mesh, intensity_elt)  # for image intensity

# Optical flow: two components, degree-1 CG in the horizontal directions,
# degree-0 DG along the extrusion direction.
flow_elt = TensorProductElement(
    FiniteElement("CG", "quadrilateral", 1),
    FiniteElement("DG", "interval", 0),
)
V = VectorFunctionSpace(mesh, flow_elt, dim=2)  # for optical flow field

# Fill the intensity function I with the image sequence. Each mesh coordinate
# is interpolated into W in turn so that, for every degree of freedom of I,
# we obtain the (frame, i, j) index of the pixel value it should hold.
x = SpatialCoordinate(mesh)
I = Function(W)

# Horizontal coordinates: truncation maps a DOF value to the index of the
# pixel containing it (presumably these are cell-centre values i + 0.5, since
# W is piecewise constant horizontally -- TODO confirm).
I.interpolate(x[0])
x_ind = I.vector()[:].astype(int)
I.interpolate(x[1])
y_ind = I.vector()[:].astype(int)

# Extrusion coordinate: the DOFs sit on the layer boundaries, so the values
# are (nominally) whole numbers. Round to nearest instead of truncating, so a
# value such as 0.9999999 coming out of the interpolation still selects
# frame 1 rather than silently falling back to frame 0.
I.interpolate(x[2])
t_ind = np.rint(I.vector()[:]).astype(int)

# Gather one pixel value per DOF and overwrite I with the image data.
pics = pic_sequence[t_ind, x_ind, y_ind]
I.vector()[:] = pics

# Mixed unknown s = (phi, p, u): phi and p live in the intensity space W,
# u in the flow space V. q is the matching test function on Q.
Q = MixedFunctionSpace([W, W, V])
s = Function(Q)
q = TestFunction(Q)
phi, p, u = split(s)

alpha = Constant(80)  # regularisation constant
n = FacetNormal(mesh)
# (u[0], u[1], 1) . n -- the flow is extended by a unit component along x[2].
u_dot_n = u[0]*n[0] + u[1]*n[1] + n[2]
un = 0.5*(u_dot_n + abs(u_dot_n)) # outflow: max(u_dot_n, 0)
un_minus = 0.5*(u_dot_n - abs(u_dot_n)) # inflow: min(u_dot_n, 0)

# Scalar functional L(phi, p, u), assembled term by term. The system that is
# solved is its derivative with respect to s (see F below).
# NOTE(review): p presumably acts as a multiplier tying phi to the transport
# residual of I under the velocity (u, 1) -- confirm against the derivation.
L = phi**2*dx  # squared phi
# Penalty on the flow gradient, weighted by alpha.
# NOTE(review): grad is taken on the extruded mesh, so it presumably includes
# the derivative along x[2] as well -- confirm this is intended.
L += alpha*inner(grad(u),grad(u))*dx
L += p*phi*dx
L += p*I*(u[0].dx(0)+u[1].dx(1))*dx  # remove if divergence-free field
# I times (u, 1) . grad(p), i.e. the derivative has been moved onto p; the
# facet terms below carry the corresponding boundary/jump contributions.
L += I*(p.dx(0)*u[0]+p.dx(1)*u[1]+p.dx(2))*dx
L += - p*I*u_dot_n*ds_t  # top
L += - p*I*u_dot_n*ds_b  # bottom
# On the vertical exterior facets the flux is split into an outflow and an
# inflow part. Both currently use the same I, so together they equal the
# single commented-out u_dot_n term; the commented-out inflow variant would
# use separate inflow data I_in, which is not defined in this script.
# L += - p*I*u_dot_n*ds_v  # outflow
L += - p*I*un*ds_v  # outflow
# L += - p*I_in*un_minus*ds_v  # inflow
L += - p*I*un_minus*ds_v  # inflow
# Interior vertical facets: jump of p times the difference of un*I between
# the '+' and '-' sides.
L += - jump(p)*((un*I)('+')-(un*I)('-'))*dS_v

# Derivative of L with respect to s in the direction q; F == 0 is the
# stationarity condition handed to the nonlinear solver.
F = derivative(L,s,q)

# PETSc options for the nonlinear solve, grouped by purpose. Optional
# debugging switches that can be added when needed:
#   'snes_view': True, 'snes_fd': True,
#   'snes_monitor_solution': ':solution.txt',
#   'snes_monitor_residual': ':residual.txt'
solver_parameters = {}

# Nonlinear iteration: Newton with the 'basic' line search, capped at 10
# iterations, with a short residual monitor.
solver_parameters.update({
    'snes_type': 'newtonls',
    'snes_linesearch_type': 'basic',
    'snes_max_it': 10,
    'snes_monitor_short': True,
})

# Linear solves: a single application of an LU factorisation via MUMPS.
solver_parameters.update({
    'ksp_type': 'preonly',
    'pc_type': 'lu',
    'pc_factor_mat_solver_package': 'mumps',
})

# Absolute tolerances.
solver_parameters.update({
    'snes_atol': 1e-10,
    'ksp_atol': 1e-10,
})
# Solve the nonlinear system F(s) = 0; the solution is written into s.
# NOTE(review): nest=False presumably requests a single monolithic matrix
# rather than a nested one, so that the direct LU solver can factorise the
# whole mixed system -- confirm against the Firedrake version in use.
solve(F==0, s, nest=False, solver_parameters=solver_parameters)