from firedrake import *
import numpy as np
import struct
from scipy import misc

def _read_png(path):
    """Decode a non-interlaced PNG with 8 bits per sample using stdlib + numpy.

    Fallback for SciPy >= 1.2, where ``scipy.misc.imread`` no longer exists.

    Returns:
        A ``uint8`` array of shape ``(rows, cols)`` for greyscale images and
        ``(rows, cols, channels)`` for grey+alpha / RGB / RGBA images.

    Raises:
        ValueError: if the file is not a PNG or uses an unsupported layout
            (bit depth other than 8, palette colour, or Adam7 interlacing).
    """
    import zlib

    with open(path, "rb") as fh:
        data = fh.read()
    if data[:8] != b"\x89PNG\r\n\x1a\n":
        raise ValueError(f"{path}: not a PNG file")

    header = None
    idat = []
    pos = 8
    while pos + 8 <= len(data):
        length, chunk_type = struct.unpack(">I4s", data[pos:pos + 8])
        body = data[pos + 8:pos + 8 + length]
        pos += length + 12  # 4-byte length + 4-byte type + body + 4-byte CRC
        if chunk_type == b"IHDR":
            header = struct.unpack(">IIBBBBB", body)
        elif chunk_type == b"IDAT":
            idat.append(body)
        elif chunk_type == b"IEND":
            break
    if header is None:
        raise ValueError(f"{path}: missing IHDR chunk")

    cols, rows, depth, color_type, _, _, interlace = header
    channels = {0: 1, 2: 3, 4: 2, 6: 4}.get(color_type)
    if depth != 8 or channels is None or interlace != 0:
        raise ValueError(
            f"{path}: only non-interlaced 8-bit grey/RGB(A) PNGs are supported"
        )

    stride = cols * channels
    raw = np.frombuffer(zlib.decompress(b"".join(idat)), dtype=np.uint8)
    # Each scanline is one filter-type byte followed by `stride` data bytes.
    raw = raw[:rows * (stride + 1)].reshape(rows, stride + 1)

    out = np.empty((rows, stride), dtype=np.uint8)
    prev = [0] * stride  # the (virtual) scanline above row 0 is all zeros
    for r in range(rows):
        ftype = int(raw[r, 0])
        line = raw[r, 1:].astype(np.int64)
        if ftype == 0:  # None
            cur = line.tolist()
        elif ftype == 1:  # Sub: add same channel of the pixel to the left
            cur = (np.cumsum(line.reshape(cols, channels), axis=0)
                   .reshape(stride) & 0xFF).tolist()
        elif ftype == 2:  # Up: add the byte directly above
            cur = ((line + np.asarray(prev, dtype=np.int64)) & 0xFF).tolist()
        elif ftype in (3, 4):  # Average / Paeth: sequential left dependency
            vals = line.tolist()
            cur = [0] * stride
            for i in range(stride):
                left = cur[i - channels] if i >= channels else 0
                up = prev[i]
                if ftype == 3:
                    pred = (left + up) // 2
                else:
                    upleft = prev[i - channels] if i >= channels else 0
                    est = left + up - upleft
                    da, db, dc = abs(est - left), abs(est - up), abs(est - upleft)
                    if da <= db and da <= dc:
                        pred = left
                    elif db <= dc:
                        pred = up
                    else:
                        pred = upleft
                cur[i] = (vals[i] + pred) & 0xFF
        else:
            raise ValueError(f"{path}: unknown PNG filter type {ftype}")
        out[r] = cur
        prev = cur
    return out if channels == 1 else out.reshape(rows, cols, channels)


def _imread(path):
    """Read an image file, preferring the legacy ``scipy.misc.imread``.

    ``misc.imread`` was removed in SciPy 1.2 (and is absent in older SciPy
    when PIL is not installed); in that case fall back to ``_read_png``.
    """
    reader = getattr(misc, "imread", None)
    if reader is not None:
        return reader(path)
    return _read_png(path)


# The two consecutive frames, as float arrays indexed [row, column].
p_0 = np.array(_imread("frame10.png"), dtype=float)
p_1 = np.array(_imread("frame11.png"), dtype=float)
# Everything below assumes 2-D greyscale frames of identical size (the
# `height, width = p_0.shape` unpack and the [frame, row, col] indexing).
if p_0.ndim != 2 or p_0.shape != p_1.shape:
    raise ValueError(
        "expected two greyscale frames of equal size, got shapes "
        f"{p_0.shape} and {p_1.shape}"
    )
# Stacked as [frame, row, column]; frame 0 is frame10, frame 1 is frame11.
pic_sequence = np.array([p_0, p_1])

# Pixel counts along the image's two axes (rows, columns).
height, width = p_0.shape

# Base mesh: `height` x `width` quadrilaterals on a `height` x `width`
# rectangle, i.e. one unit-sized cell per pixel (x runs along rows,
# y along columns). It is extruded by a single layer; the extruded
# coordinate is later used as the time axis between the two frames.
m = RectangleMesh(height, width, height, width, quadrilateral=True)
mesh = ExtrudedMesh(m, layers=1)

# Optical-flow field: continuous degree-1 vector field with two components.
V = VectorFunctionSpace(mesh, "CG", 1, dim=2)

# Image intensity: piecewise constant per pixel horizontally (DG0 on the
# quadrilateral) and continuous linear in the extruded direction (CG1).
W_v = FiniteElement("CG", "interval", 1)
W_h = FiniteElement("DG", "quadrilateral", 0)
W_elt = TensorProductElement(W_h, W_v)
W = FunctionSpace(mesh, W_elt)

# Fill I with the pixel data. I is first used as scratch storage: interpolating
# each coordinate component into W yields, for every degree of freedom, the
# position of that DOF. Those positions are converted to (frame, row, column)
# indices into pic_sequence.
x = SpatialCoordinate(mesh)
I = Function(W)

# Horizontally W is DG0, so DOFs sit at cell centres (k + 0.5 for unit cells);
# flooring recovers the pixel index k.
I.interpolate(x[0])
x_ind = np.floor(np.array(I.vector()[:])).astype(int)
I.interpolate(x[1])
y_ind = np.floor(np.array(I.vector()[:])).astype(int)

# Vertically W is CG1, so DOFs sit on the layer vertices (presumably z = 0 and
# z = 1 with Firedrake's default layer height for layers=1 -- one per frame).
# Round to the nearest integer: int() truncation would send a value such as
# 0.9999999999 produced by floating-point interpolation to frame 0.
I.interpolate(x[2])
z_ind = np.rint(np.array(I.vector()[:])).astype(int)

pics = np.asarray(pic_sequence[z_ind, x_ind, y_ind], dtype=float)
I.vector()[:] = pics

# Mixed unknown s = (phi, p, u) on Q = W x W x V. From how they enter L below:
# phi is an auxiliary field matched against dI/dt in L1, p multiplies the
# constraint terms in L3, and u is the 2-component flow field.
Q = MixedFunctionSpace([W, W, V])
q = TestFunction(Q)
s = Function(Q)
phi, p, u = split(s)

# Weight of the flow-smoothness term L2. Previously defined but unused; it now
# scales L2 (the value 1.0 leaves the functional unchanged).
alpha = 1.0
n = FacetNormal(mesh)
# Horizontal normal flux of u and its positive (outflow) part, for upwinding.
u_dot_n = u[0]*n[0] + u[1]*n[1]
un = 0.5*(u_dot_n + abs(u_dot_n))
# Data term: (phi + dI/dt)^2, with the extruded coordinate x[2] as time.
L1 = (phi + I.dx(2))**2*dx
# Smoothness term on the flow field.
L2 = alpha*inner(grad(u), grad(u))*dx
# Constraint term tested with p. I is piecewise constant per pixel
# horizontally, so its spatial variation lives only in the jumps across the
# *vertical* interior facets (dS_v) between neighbouring pixels. The upwind
# flux must therefore be integrated over dS_v. It was previously dS_h, which
# (a) has no interior facets on a single-layer mesh and (b) has n = (0, 0, +-1),
# so u_dot_n == 0 there; the term vanished and the image's spatial structure
# never entered the problem. NOTE(review): together with the cell term this
# appears to form an upwind DG discretisation of p*(u . grad I); confirm signs.
L3 = (p*phi*dx
      + p*I*(u[0].dx(0) + u[1].dx(1))*dx
      - p*I*un*ds_v
      - jump(p)*(un('+')*I('+') - un('-')*I('-'))*dS_v)
L = L1 + L2 + L3

# Residual of the optimality system: directional derivative of the
# functional L with respect to s in the test direction q.
F = derivative(L, s, q)

# Clamp the flow field u (sub-space 2 of Q) to zero on boundary markers 1-4.
# Per Firedrake's RectangleMesh numbering these are the edges of the base
# rectangle, i.e. the side walls of the extruded mesh.
bc3 = DirichletBC(Q.sub(2), Constant((0, 0)), (1, 2, 3, 4))
bcs = [bc3]

# Nonlinear (SNES) solve with a direct LU factorisation for each linear step.
lu_params = {
    'ksp_type': 'preonly',
    'pc_type': 'lu',
    'snes_atol': 1e-10,
    'ksp_atol': 1e-10,
}
solve(F == 0, s, bcs=bcs, nest=False, solver_parameters=lu_params)

