from firedrake import *
# ---------------------------------------------------------------------------
# Parameters
# ---------------------------------------------------------------------------
m = 20            # NOTE(review): not referenced anywhere in this script
timestep = 0.05   # step size of the backward-Euler time stepping
g = 9.81          # gravitational constant

# ---------------------------------------------------------------------------
# Mesh and mixed DG space for the coupled (h, hu) system
# ---------------------------------------------------------------------------
mesh = IntervalMesh(100, -2, 2)    # 100 cells on the interval [-2, 2]
V = FunctionSpace(mesh, "DG", 2)   # discontinuous quadratics
W = V * V                          # mixed space: (height, momentum)
n = FacetNormal(mesh)              # facet normal used by the flux terms

# ---------------------------------------------------------------------------
# Initial conditions (dam break: deep water left of x = 0, shallow to the right)
# ---------------------------------------------------------------------------
w_ = Function(W)                   # state at the previous time level
h_, hu_ = w_.split()               # component Functions, so interpolation fills w_
starting_height = Expression("(x[0] < 0.0) ? 1.0 : 0.12")
h_.interpolate(starting_height)
velocity = Expression("0.0")
hu_.interpolate(velocity)          # zero momentum: the fluid starts at rest

# NOTE: the residual imposes the dam-break states at the domain ends. That is
# only adequate until the shocks reach the boundary, so keep the final time short.

# Initial guess for the nonlinear solver: a copy of the initial state.
w = Function(W)
w.assign(w_)




# Setting up residual form


(phi, xi) = TestFunctions(W)
# UFL split() gives symbolic components usable inside the form:
# h_/hu_ are the previous-step values, h/hu the unknowns being solved for.
h_, hu_ = split(w_)
h, hu = split(w)

# Physical fluxes of the 1D shallow-water system in conservative form
#   h_t  + (hu)_x                   = 0
#   hu_t + (hu^2/h + g*h^2/2)_x     = 0
F_h = hu
F_hu = hu*hu/h + 0.5*g*h*h

# Roe averages across an interior facet, used for the local wave speed.
# The Roe velocity weights the *velocities* u = hu/h by sqrt(h); the Roe
# height is the arithmetic mean.
hbar = 0.5*(h('-') + h('+'))
sqrt_hm = pow(h('-'), 0.5)
sqrt_hp = pow(h('+'), 0.5)
ubar = (sqrt_hm*hu('-')/h('-') + sqrt_hp*hu('+')/h('+')) / (sqrt_hm + sqrt_hp)
hubar = hbar*ubar
# Largest characteristic speed |u| + c (c = sqrt(g*h) >= 0), i.e. the
# maximum of |u + c| and |u - c|.
alpha = abs(ubar) + pow(g*hbar, 0.5)

# Dam-break states imposed at the ends of the domain (they match the initial
# condition); the momentum there is taken from the interior solution.
# IntervalMesh marks the left end with 1 and the right end with 2.
h_left = 1.0
h_right = 0.12

# Residual of the backward-Euler DG discretisation. After integrating the flux
# divergence by parts, every facet contributes (numerical flux . n) * test.
# On interior facets the local Lax-Friedrichs (Rusanov) flux is
#   F_hat . n('+') = avg(F) n('+') + 0.5*alpha*jump(q),  jump(q) = q('+') - q('-'),
# and the summed contribution of both sides is (F_hat . n('+')) * jump(v).
# This form does not depend on which cell UFL labels '+', and the dissipation
# term 0.5*alpha*jump(q)*jump(v) is always non-negative.
L = (
    # height (mass) equation
    (h - h_)/timestep*phi*dx - F_h*phi.dx(0)*dx
    # momentum equation
    + (hu - hu_)/timestep*xi*dx - F_hu*xi.dx(0)*dx
    # height interior flux
    + (avg(F_h)*n[0]('+') + 0.5*alpha*jump(h))*jump(phi)*dS
    # momentum interior flux
    + (avg(F_hu)*n[0]('+') + 0.5*alpha*jump(hu))*jump(xi)*dS
    # boundary fluxes, using the outward normal n[0] (-1 at marker 1, +1 at marker 2)
    + F_h*n[0]*phi*ds(1) + F_h*n[0]*phi*ds(2)
    + (hu*hu/h_left + 0.5*g*h_left*h_left)*n[0]*xi*ds(1)
    + (hu*hu/h_right + 0.5*g*h_right*h_right)*n[0]*xi*ds(2)
)

# Nonlinear problem "L(w; phi, xi) = 0" for the new state w.
# nest=False requests a single monolithic matrix for the mixed system instead
# of a nested block matrix, which a direct LU factorisation cannot handle.
uprob = NonlinearVariationalProblem(L, w, nest=False)

# Each Newton step solves its linear system directly: one application of an
# LU preconditioner and no Krylov iterations.
usolver = NonlinearVariationalSolver(
    uprob,
    solver_parameters=dict(ksp_type='preonly', pc_type='lu'),
)

# Output collections for height and momentum.
hfile = File("./Results/height.pvd")
hufile = File("./Results/momentum.pvd")

# Function.split() yields Functions for the components of w. They follow w as
# it is updated, which lets the time loop write them after every solve.
h, hu = w.split()

# The initial condition is the first frame of each output file.
hfile << h
hufile << hu

# Simulation clock: start at t = 0 and advance towards `end`.
t = 0.0
end = .06

# Time loop: one backward-Euler step per iteration.
# Testing against end - 0.5*timestep (instead of t <= end) runs
# round(end/timestep) steps. It never advances a whole extra step past `end`
# and is insensitive to round-off accumulating in t.
while t < end - 0.5*timestep:
    # Solve the nonlinear system for the new state w.
    usolver.solve()

    # The new state becomes the previous state for the next step.
    w_.assign(w)

    # Advance the clock.
    t += timestep

    # h and hu come from w.split(), so they hold the new solution.
    # Write both fields, matching what was written for the initial condition.
    hfile << h
    hufile << hu

