from firedrake import *
import numpy

# Base (2d) mesh on the unit square with 6 x 6 divisions.
mesh2d = UnitSquareMesh(6, 6)
# Extrude the base mesh into 5 layers of height 0.2 each (total height 1.0).
mesh = ExtrudedMesh(mesh2d, layers=5, layer_height=0.2)

# Discontinuous piecewise-linear ('DG', degree 1) spaces on the base mesh
# and on the extruded mesh.
P1DG_2d = FunctionSpace(mesh2d, 'DG', 1)
P1DG = FunctionSpace(mesh, 'DG', 1)

# Source field on the 2d mesh and target field on the 3d mesh; ``name``
# sets each field's label (presumably what appears in the .pvd output).
phi_2d = Function(P1DG_2d, name='phi2d')
phi_3d = Function(P1DG, name='phi3d')

# Fill the 2d field by projecting the analytic expression x*sin(pi*y).
phi_2d.project(Expression('x[0]*sin(pi*x[1])'))

def copy_2d_field_to_3d(input2d, output3d):
    """Copy a 2d function into a 3d function on the extruded mesh.

    The generated kernel visits the extruded cells and, for every node ``d``
    of the 2d base cell, writes that node's value into the ``n_vert_nodes``
    consecutive 3d nodes starting at the matching bottom node ``idx[d]``.
    The result is therefore constant in the vertical direction within each
    extruded cell.

    :arg input2d: source :class:`Function` on the base (2d) mesh.
    :arg output3d: target :class:`Function` on the extruded (3d) mesh; its
        values are overwritten.
    :returns: ``output3d``, for convenient chaining.
    :raises ValueError: if the two function spaces cannot be paired by the
        kernel. The generated C code performs no bounds checks, so such a
        mismatch would otherwise silently read or write out of range.
    """
    fs_2d = input2d.function_space()
    fs_3d = output3d.function_space()

    # Number of nodes in the vertical direction: the dofs in the closure of
    # the cell of the vertical factor (``B``) of the 3d element -- assumes
    # fs_3d is built on a tensor-product element, as on an extruded mesh.
    n_vert_nodes = len(fs_3d.fiat_element.B.entity_closure_dofs()[1][0])

    # Local (per-cell) indices of the 3d nodes on the bottom of the cell.
    # NOTE(review): the kernel assumes bottom node d lines up with 2d node d
    # and that its vertical neighbours follow contiguously -- confirm for
    # element families other than the P1DG used in this script.
    nodes = fs_3d.bt_masks['geometric'][0]
    n_nodes_2d = fs_2d.cell_node_map().arity
    n_nodes_3d = fs_3d.cell_node_map().arity

    # Validate the pairing up front; the C kernel would not notice.
    if len(nodes) != n_nodes_2d:
        raise ValueError(
            "2d cell has %d nodes but 3d cell has %d bottom nodes"
            % (n_nodes_2d, len(nodes)))
    if len(nodes) and int(max(nodes)) + n_vert_nodes > n_nodes_3d:
        raise ValueError(
            "vertical node stacks exceed the %d nodes of a 3d cell"
            % n_nodes_3d)
    if fs_2d.cdim > fs_3d.cdim:
        raise ValueError(
            "2d value dimension %d exceeds 3d value dimension %d"
            % (fs_2d.cdim, fs_3d.cdim))

    idx = op2.Global(len(nodes), nodes, dtype=numpy.int32, name='nodeIdx')
    kernel_code = """
        void my_kernel(double **func, double **func2d, int *idx) {
            for ( int d = 0; d < %(nodes)d; d++ ) {
                for ( int c = 0; c < %(func_dim)d; c++ ) {
                    for ( int e = 0; e < %(v_nodes)d; e++ ) {
                        func[idx[d]+e][c] = func2d[d][c];
                    }
                }
            }
        }""" % {'nodes': n_nodes_2d,
                'func_dim': fs_2d.cdim,
                'v_nodes': n_vert_nodes}
    kernel = op2.Kernel(kernel_code, 'my_kernel')

    # op2.ALL: iterate over every layer of the extruded cell set (rather
    # than only the bottom or top layer).
    op2.par_loop(
        kernel, fs_3d.mesh().cell_set,
        output3d.dat(op2.WRITE, fs_3d.cell_node_map()),
        input2d.dat(op2.READ, fs_2d.cell_node_map()),
        idx(op2.READ),
        iterate=op2.ALL)

    return output3d

# Extend phi_2d vertically into phi_3d; the return value (phi_3d itself)
# is not needed here.
copy_2d_field_to_3d(phi_2d, phi_3d)

# Write both fields to ParaView .pvd files for visual comparison.
out_2d = File('phi2d.pvd')
out_3d = File('phi3d.pvd')

out_2d << phi_2d
out_3d << phi_3d
