from firedrake import *
# --- Time-stepping parameters ------------------------------------------------
t = 0              # current simulation time
step = 0           # number of time steps taken so far
T = 10.            # final time; the loop runs while t <= T
dt = 0.001         # time-step size
theta = 0.5        # weight of the new time level (1 - theta goes to the old one)
lump_mass = True   # True: lumped-mass update, False: consistent-mass linear solve

# --- Spatial discretisation --------------------------------------------------
# 20 x 20 triangulated unit square with first-order Lagrange elements.
mesh = UnitSquareMesh(20, 20, reorder=None, quadrilateral=False)
V = FunctionSpace(mesh, "Lagrange", 1)
u = TrialFunction(V)
v = TestFunction(V)

# --- Fields ------------------------------------------------------------------
# NOTE(review): `p` and `phi` are not referenced in the visible part of the
# script; they are kept in case later code relies on them.
p = Function(V, name="p")
phi = Function(V, name="phi")

# Current ("n") and previous ("n_1") time levels of p and phi.
pn = Function(V, name="p")
pn_1 = Function(V, name="p last")
phin = Function(V, name="phi")
phin_1 = Function(V, name="phi last")

# --- Forcing term f(x, y, t) = x * y * t -------------------------------------
# Both time levels are initialised from the expression at the initial time.
f_expr = Expression("x[0]*x[1]*t", t=t)
fn = Function(V, name="f")
fn_1 = Function(V, name="f last")
fn.interpolate(f_expr)
fn_1.interpolate(f_expr)

# --- Boundary condition ------------------------------------------------------
# Dirichlet condition on boundary id 1; its value lives in a Constant so it
# can be re-assigned every time step without rebuilding the condition.
bcval = Constant(0.0)
bc = DirichletBC(V, bcval, 1)

# --- Output ------------------------------------------------------------------
# Write the initial phi field before time stepping starts.
outfile = File("waveEqn1.pvd")
outfile << phin
# Time loop: advance (phi, p) from the current t up to T with a theta-weighted
# predictor/corrector update, driven by the forcing f and by a time-dependent
# Dirichlet value on the boundary handled by `bc`.
#
# All state is held in the Function objects pn, pn_1, phin, phin_1 and is
# updated in place with .assign(). Rebinding these names to UFL expressions
# (e.g. `phin = phin_1 - ...`) would create lazy expressions that are only
# evaluated later, after pn_1 / phin_1 have already been overwritten, and
# would hand a non-Function object to the output file.

if lump_mass:
    # Lumped mass vector. It depends only on the mesh and the test space, so
    # it is assembled once instead of once per time step.
    lumped_mass = assemble(v*dx)

while t <= T:
    step += 1

    # Time-dependent Dirichlet value for this step.
    bcval.assign(sin(2*pi*5*t))

    # Predictor for phi. pn still holds the previous step's value here
    # (pn_1 was assigned from it at the end of the last iteration).
    phin.assign(phin_1 - (dt * theta * pn + dt * (1 - theta) * pn_1))

    if lump_mass:
        # theta weights the whole new-time-level right-hand side (stiffness
        # and forcing), mirroring the (1 - theta) term and the
        # consistent-mass branch below.
        rhs = assemble(
            dt * theta * (inner(nabla_grad(v), nabla_grad(phin))*dx + fn*v*dx)
            + dt * (1 - theta) * (inner(nabla_grad(v), nabla_grad(phin_1))*dx + fn_1*v*dx)
        )
        pn.assign(pn_1 + rhs / lumped_mass)
        # Impose the boundary value on the freshly computed field; applying
        # it to pn_1 would be overwritten by pn_1.assign(pn) below.
        bc.apply(pn)
    else:
        # Consistent-mass update. pn_1 equals the incoming value of pn, so
        # using it on the right-hand side makes the time level explicit.
        solve(u * v * dx == v * pn_1 * dx
              + dt * theta * (inner(grad(v), grad(phin)) * dx + fn*v*dx)
              + dt * (1 - theta) * (inner(grad(v), grad(phin_1)) * dx + fn_1*v*dx),
              pn, bcs=bc,
              solver_parameters={'ksp_type': 'cg', 'pc_type': 'sor', 'pc_sor_symmetric': True})

    # Corrector for phi using the new pn.
    phin.assign(phin_1 - (dt * theta * pn + dt * (1 - theta) * pn_1))

    # Shift time levels: current values become the "last" values.
    pn_1.assign(pn)
    phin_1.assign(phin)
    t += dt

    # Advance the forcing: keep the old level, re-evaluate at the new time.
    fn_1.assign(fn)
    f_expr.t = t
    fn.interpolate(f_expr)

    # Write every 10th step to keep the output file manageable.
    if step % 10 == 0:
        outfile << phin
                  
