from firedrake import *
# Theta-scheme time stepping of the first-order system
#     d(phi)/dt = -p
#     (dp/dt, v) = (grad(phi), grad(v)) + (f, v)
# on the unit square with P1 Lagrange elements. Judging by the output file
# name this is a forced wave equation; the discretisation below is what the
# code actually implements.
mesh = UnitSquareMesh(20, 20, reorder=None, quadrilateral=False)
V = FunctionSpace(mesh, 'Lagrange', 1)

# "n" suffix: value at the new time level; "n_1" suffix: previous level.
p = Function(V, name="p")  # not used by the time loop below
pn = Function(V, name="p")
pn_1 = Function(V, name="p last")
phi = Function(V, name="phi")  # the field written to the output file
phin = Function(V, name="phi")
phin_1 = Function(V, name="phi last")
u = TrialFunction(V)
v = TestFunction(V)

# The mesh is two-dimensional, so only x[0] and x[1] exist. The time `t` is
# passed to the Expression as a user-defined value so that it can be updated
# through `forcing.t` before each re-interpolation.
forcing = Expression("x[0]*x[1]*t", t=0.0)
f = Function(V).interpolate(forcing)
fn = Function(f, name="f")
fn_1 = Function(f, name="f last")

outfile = File("waveEqn1.pvd")
outfile << phi

bcval = Constant(0.0)
bc = DirichletBC(V, bcval, 1)

T = 10.
dt = 0.001
theta = 0.5
t = 0
step = 0
lump_mass = True

# The forms below refer to the Function objects themselves. Every update in
# the loop is done in place with .assign()/.interpolate(), so the forms can
# be built once and always see the current data.
rhs = dt * (theta * (inner(grad(v), grad(phin)) * dx + fn * v * dx)
            + (1 - theta) * (inner(grad(v), grad(phin_1)) * dx + fn_1 * v * dx))
# Row sums of the mass matrix; loop-invariant, so assembled only once.
lumped_mass = assemble(v * dx)

while t <= T:
    step += 1
    bcval.assign(sin(2*pi*5*t))

    # Forcing at the new time level; fn_1 still holds the previous level.
    forcing.t = t + dt
    fn.interpolate(forcing)

    # Predictor for phi. At this point pn still holds the previous-level
    # values (zero on the first step, copied from pn_1 afterwards).
    phin.assign(phin_1 - dt * (theta * pn + (1 - theta) * pn_1))

    if lump_mass:
        pn.assign(pn_1 + assemble(rhs) / lumped_mass)
        # Impose the boundary value on the freshly computed level; applying
        # it to pn_1 would be overwritten by the copy further down.
        bc.apply(pn)
    else:
        solve(u * v * dx == v * pn_1 * dx + rhs,
              pn, bcs=bc,
              solver_parameters={'ksp_type': 'cg',
                                 'pc_type': 'sor',
                                 'pc_sor_symmetric': True})

    # Corrector for phi using the new pn.
    phin.assign(phin_1 - dt * (theta * pn + (1 - theta) * pn_1))
    phi.assign(phin)

    # Shift the new level into the "previous level" slots. Everything above
    # is stored by value, so the order of these copies does not matter.
    pn_1.assign(pn)
    fn_1.assign(fn)
    phin_1.assign(phin)

    t += dt
    if step % 10 == 0:
        outfile << phi
                  
