from firedrake import *
from ffc import log
log.set_level(log.ERROR)
op2.init(log_level="WARNING")
from firedrake.ffc_interface import compile_form
from firedrake.fiat_utils import *

# Polynomial degree of the vertical DG element V1; V1 is combined with the
# horizontal element U2 to build the function space W3 further down.
order_vertical = 0

def build_lma(mesh, ufl_form):
    """Store the per-cell local matrices of a bilinear form in a DG0 field.

    The form is compiled with FFC and its kernel is run by a PyOP2 par_loop
    over every cell of a DG0 function space on ``mesh``. Each DG0 node
    carries ``nrow*ncol`` values, where ``nrow``/``ncol`` are the arities of
    the test/trial cell->node maps, i.e. room for one local matrix per cell.

    :arg mesh: mesh on which the DG0 storage space is built.
    :arg ufl_form: a UFL form with exactly two arguments (test, trial).
    :returns: the DG0 :class:`Function` holding the local matrix data.
    :raises ValueError: if ``ufl_form`` does not have exactly two arguments.
    """
    arguments = ufl_form.arguments()
    # Validate with an explicit raise (an ``assert`` vanishes under
    # ``python -O``) and do it before paying for form compilation.
    if len(arguments) != 2:
        raise ValueError('Not a bilinear form')

    # Only the first entry returned by compile_form is used.
    # NOTE(review): a form that compiles to several kernels would be silently
    # truncated here; also the tuple layout (3 -> coordinates,
    # 4 -> coefficients, 6 -> kernel) is tied to this firedrake version --
    # confirm against compile_form before upgrading.
    compiled_form = compile_form(ufl_form, 'ufl_form')[0]
    coords = compiled_form[3]
    coefficients = compiled_form[4]
    kernel = compiled_form[6]

    # Arity of the cell->node maps: nodes per cell for the test (rows) and
    # trial (columns) spaces.
    nrow = arguments[0].cell_node_map().arity
    ncol = arguments[1].cell_node_map().arity

    # One DG0 node per cell, each carrying nrow*ncol values.
    V_lma = FunctionSpace(mesh, 'DG', 0)
    lma = Function(V_lma, val=op2.Dat(V_lma.node_set**(nrow*ncol)))

    # Output is incremented; coordinates and every coefficient are read-only.
    args = [lma.dat(op2.INC, lma.cell_node_map()[op2.i[0]]),
            coords.dat(op2.READ, coords.cell_node_map(), flatten=True)]
    args.extend(c.dat(op2.READ, c.cell_node_map(), flatten=True)
                for c in coefficients)

    op2.par_loop(kernel, lma.cell_set, *args)
    return lma

        
# ---- Extruded mesh --------------------------------------------------------
# A circle of `ncells` cells, extruded radially; the total height D is split
# evenly across `nlayers` layers.
ncells = 3
host_mesh = CircleManifoldMesh(ncells)
D = 0.1
nlayers = 1
mesh = ExtrudedMesh(host_mesh, layers=nlayers, extrusion_type='radial',
                    layer_height=D / nlayers)

# ---- Finite elements ------------------------------------------------------
# Horizontal factor: piecewise constants on intervals.
U2 = FiniteElement('DG', interval, 0)
# Vertical factors: continuous P1, and discontinuous of degree order_vertical.
V0 = FiniteElement('CG', interval, 1)
V1 = FiniteElement('DG', interval, order_vertical)

# ---- Function spaces ------------------------------------------------------
# W2_vert: HDiv-wrapped tensor product U2 x V0; W3: tensor product U2 x V1.
W2_vert = FunctionSpace(mesh, HDiv(OuterProductElement(U2, V0)))
W3 = FunctionSpace(mesh, OuterProductElement(U2, V1))

# ---- Bilinear forms for the vertical divergence and its transpose ---------
phi, psi = TestFunction(W3), TrialFunction(W3)
u, w = TestFunction(W2_vert), TrialFunction(W2_vert)

form_D = phi * div(w) * dx     # test in W3, trial in W2_vert
form_DT = div(u) * psi * dx    # test in W2_vert, trial in W3

# ---- Per-cell local matrices ----------------------------------------------
lma_DT = build_lma(mesh, form_DT)
lma_D = build_lma(mesh, form_D)

# Report each operator: a header line followed by its raw per-cell data.
for header, lma in (('*** D:  W2_vert -> W3 ***', lma_D),
                    ('*** DT: W3 -> W2_vert ***', lma_DT)):
    print(header)
    print(lma.dat.data)

