from __future__ import division
from firedrake import *

# ---- Model parameters ----
# Time-step size; enters every (1/dt) factor in the weak forms below.
dt = 5.0e-04

# Dimensionless groups.
Cn = 0.03      # Cahn number
We = 0.45      # Weber number
Pe = 45000     # Peclet number
Re = 100       # Reynolds number
Fr2 = 0.1      # Froude number squared (only used when gravity is enabled)

# Coefficient of the stabilising (beta/Cn)*phi term that appears on both
# sides of the chemical-potential equation.
beta = 2

# Gravity switch: True adds the body-force term to the right-hand side,
# False leaves it out.
gravity = True

# Density parameters -- presumably the densities of the two phases; only the
# ratio rho2/rho1 and the combination in `alpha` are used further down.
rho1 = Constant(2)
rho2 = Constant(1)

# d(rho_hat)/d(phi) for rho_hat = 0.5*((1 + phi) + (rho2/rho1)*(1 - phi)).
drho_hat = Constant(0)
drho_hat.assign(0.5*(1 - (rho2/rho1)))

# Density contrast factor used in the divergence equation (a4).
alpha = (rho2 - rho1)/(rho2 + rho1)

# ---- Mesh and function spaces ----
# Unit square with 100 subdivisions in each direction.
mesh = UnitSquareMesh(100, 100)

# Scalar piecewise-linear space and vector piecewise-quadratic space.
P1 = FunctionSpace(mesh, "Lagrange", 1)
P2 = VectorFunctionSpace(mesh, "Lagrange", 2)

# Mixed space ordered as (velocity, pressure, phi, mu): Taylor-Hood (P2-P1)
# for velocity/pressure, P1-P1 for the order parameter and chemical potential.
W = P2*P1*P1*P1

# ---- Boundary conditions ----
# Zero velocity on boundary markers 1, 2, 3 and 4 (presumably the four sides
# of the unit square -- confirm against the mesh's marker numbering).
noslip = Constant((0, 0))
bc0, bc1, bc2, bc3 = (
    DirichletBC(W.sub(0), noslip, marker) for marker in (1, 2, 3, 4)
)

# ---- Test functions, one per component of W ----
# chi     <-> velocity
# theta   <-> pressure
# omega   <-> phi
# psi     <-> mu
chi, theta, omega, psi = TestFunctions(W)

# ---- Unknowns ----
u = Function(W)     # solution at the new time level
u0 = Function(W)    # solution from the previous converged step

# Symbolic components of the mixed functions, for use inside the forms.
v, p, phi, mu = split(u)
v0, p0, phi0, mu0 = split(u0)

# Compute the chemical potential df/dphi
# Wrap phi0 as a UFL variable so that expressions can be differentiated
# with respect to it via diff(..., phi0).
phi0 = variable(phi0)
#f    = 0.25*(phi**2 - 1)**2
#dfdphi = diff(f, phi)

class FreeEnergyPotential(Expression):
    """Piecewise free-energy density.

    For an argument s the value written is
        (s - 1)**2              if s > 1,
        (s + 1)**2              if s < -1,
        0.25*(s**2 - 1)**2      otherwise.

    NOTE(review): the second argument of ``eval`` is named ``phi0`` but is
    whatever Firedrake passes to ``Expression.eval`` -- presumably the
    coordinates of the evaluation point rather than the phase field; confirm
    against the Expression API before relying on this class.
    """

    def eval(self, values, phi0):
        # Select the well the quadratic branch is centred on; inside
        # [-1, 1] the quartic double well is used and we are done.
        if phi0 > 1:
            well = 1
        elif phi0 < -1:
            well = -1
        else:
            values[:] = 0.25*(phi0**2 - 1)**2
            return
        values[:] = (phi0 - well)**2
	
# Free energy written symbolically in terms of the UFL variable phi0:
# a quartic double well on [-1, 1], continued quadratically outside it.
#
# It has to be a UFL expression in phi0 rather than an interpolated
# Function: differentiating a Function with respect to a UFL variable gives
# zero, which would silently remove the (1/Cn)*dfdphi*psi*dx term from the
# right-hand side.
f = conditional(
    gt(phi0, 1),
    (phi0 - 1)**2,
    conditional(
        lt(phi0, -1),
        (phi0 + 1)**2,
        0.25*(phi0**2 - 1)**2,
    ),
)
# Chemical-potential contribution df/dphi, evaluated at the previous step.
dfdphi = diff(f, phi0)

#class DfDphi(Expression):
#	def eval(self, values, phi0):
#		if phi > 1:
#			values[:] = (phi0-1)*2
#		elif phi < -1:
#			values[:] = (phi0+1)*2
#		else:
#			values[:] = phi0*(phi0**2 -1)


# ---- Density fields ----
rho = Function(P1)      # density at the new time level
rho0 = Function(P1)     # density at the previous time level
eta = TestFunction(P1)  # test function for the density equation

# Density expressed through the previous-step order parameter:
# equals 1 where phi0 = 1 and rho2/rho1 where phi0 = -1.
rho_hat = 0.5*((1 + phi0) + (rho2/rho1)*(1 - phi0))

# Initialise the density from that same expression.
# NOTE(review): at this point in the script u0 has not been given its
# initial condition yet, so phi0 carries whatever a fresh Function holds.
rho0.interpolate(rho_hat)

# Cartesian unit vectors; ey is the direction in which gravity acts.
ex, ey = unit_vectors(2)

# ---- Weak statement of the equations (terms involving the unknown u) ----
# Order-parameter transport, tested with omega.
a1 = (
    (1/dt)*omega*phi*dx
    + omega*div(v*phi0)*dx
    + (1/Pe)*inner(grad(omega), grad(mu))*dx
)

# Chemical-potential equation, tested with psi.
a2 = (
    psi*mu*dx
    - (beta/Cn)*phi*psi*dx
    - Cn*inner(grad(phi), grad(psi))*dx
    - (1/rho_hat)*drho_hat*psi*p*dx
)

# Momentum equation, tested with chi: inertia and advection (linearised
# about v0), pressure / capillary terms scaled by 1/We, and viscous terms
# scaled by 1/Re.
a3 = (
    rho0*(1/dt)*inner(chi, v)*dx
    + rho0*inner(dot(v0, nabla_grad(v)), chi)*dx
    + (1/We)*(
        -p*div(chi)
        + inner(chi, phi*grad(mu))
        - (1/rho_hat)*drho_hat*p*inner(chi, grad(phi0))
    )*dx
    + (1/Re)*inner(
        (nabla_grad(v)[i, j] + nabla_grad(v)[j, i]),
        nabla_grad(chi)[i, j]
    )*dx
    - (2/(3*Re))*div(v)*div(chi)*dx
)

# Divergence constraint, tested with theta.
a4 = theta*div(v)*dx + (alpha/Pe)*inner(grad(theta), grad(mu))*dx

# Bilinear form: Gateaux derivative of the summed residual terms w.r.t. u.
a = derivative(a1 + a2 + a3 + a4, u)

# ---- Right-hand side (terms built from the previous step only) ----
L1 = (1/dt)*omega*phi0*dx
L2 = -(beta/Cn)*phi0*psi*dx + (1/Cn)*dfdphi*psi*dx
L3 = rho0*(1/dt)*inner(chi, v0)*dx
if gravity:
    # Body force along ey, scaled by 1/Fr2.
    L3 = L3 + (-(1/Fr2)*rho_hat*inner(chi, ey)*dx)

L = L1 + L2 + L3


# Rebind the names to the component Functions of u0 (instead of the symbolic
# split used inside the forms) so that they can be interpolated into.
v0, p0, phi0, mu0 = u0.split()
# Initial conditions for phi 
#phi_ini = Expression('''1 - tanh( ( sqrt( pow(x[0]-x0, 2)+pow(x[1]-y0, 2) ) - r0 )/C*sqrt(2) )- tanh( ( sqrt( pow(x[0]-x1, 2)+pow(x[1]-y1, 2) ) - r1 )/C*sqrt(2) )''', x0=0.4,y0=0.5,r0 = 0.25,x1=0.78,y1=0.5,r1 = 0.1,C=0.01 )  
#phi_ini = Expression(''' - tanh( ( sqrt( pow(x[0]-x0, 2)+pow(x[1]-y0, 2) ) - r0 )/C*sqrt(2) )''', x0=0.5,y0=0.9,r0 = 0.1,C=0.03 ) 
#phi_ini = Expression('''1 - tanh( ( sqrt( pow(x[0]-x0, 2)+pow(x[1]-y0, 2) ) - r0 )/C*sqrt(2) )- tanh( ( sqrt( pow(x[0]-x1, 2)+pow(x[1]-y1, 2) ) - r1 )/C*sqrt(2) )''', x0=0.5,y0=0.65,r0 = 0.15,x1=0.5,y1=0.35,r1 = 0.15,C=0.0625 )  
# Sharp horizontal interface: phi = 1 below y = 0.5 and phi = -1 above.
phi_ini = Expression("x[1] <  (0.5) ? 1 : -1")
phi0.interpolate(phi_ini)

# (Re-)initialise the previous-step density from the phase field now that
# phi0 actually holds its initial profile. Any density initialisation done
# before this point saw phi0 without its initial condition and therefore
# could not reflect the interface.
rho0.interpolate(0.5*((1 + phi0) + (rho2/rho1)*(1 - phi0)))

# Initial conditions for velocity: fluid at rest.
vel_ini = Expression(("0.0", "0.0"))
v0.interpolate(vel_ini)


### Weak equation for density
# Left-hand side: (1/dt)*rho tested with eta; taking the derivative with
# respect to rho turns it into the bilinear form used in the solve.
b1 = (1/dt)*eta*rho*dx
b = derivative(b1, rho)
# Right-hand side: previous density minus the divergence of the flux
# rho0*v0, where v0 is the velocity component of u0.
M = (1/dt)*eta*rho0*dx - eta*div(rho0*v0)*dx

### Sanity checks (all of these should be conserved)
# Integral of the order parameter of the current solution u.
order_parameter = phi*dx
# Integral of the transported density.
total_mass = rho*dx
# Integral of the density reconstructed from the previous-step phase field.
total_mass2 = rho_hat*dx

# ---- Time range ----
t = 0
tend = 1.0

# ---- Output ----
outfile = File("./Results_NSCH5/nsch_res.pvd")

# From here on v, p, phi, mu refer to the component Functions of u0; they
# are what gets written to file. (The forms defined earlier keep their own
# references to the symbolic components.)
v, p, phi, mu = u0.split()

# Human-readable field names for the output file.
v.rename("Velocity")
p.rename("Pressure")
phi.rename("Order parameter")
mu.rename("Chemical potential")

# Store the initial state.
outfile.write(v, p, phi, mu, time=t)

# Assemble the mixed system as a single monolithic matrix rather than a
# nested one.
parameters["matnest"] = False

# ---- Time stepping ----
# The number of steps is fixed up front and t is recomputed from the step
# index. Accumulating t += dt and testing `t <= tend` makes the step count
# depend on floating-point round-off and always advances at least one step
# beyond tend.
t_start = t
num_steps = int(round((tend - t_start)/dt))

for step in range(1, num_steps + 1):
    t = t_start + step*dt
    print(t)

    # Solve the coupled system for (velocity, pressure, phi, mu) with a
    # direct LU factorisation (MUMPS).
    solve(a == L, u, bcs=[bc0, bc1, bc2, bc3],
          solver_parameters={'pc_type': 'lu',
                             'pc_factor_mat_solver_package': 'mumps'})

    # Update the density using the same direct solver settings.
    solve(b == M, rho,
          solver_parameters={'pc_type': 'lu',
                             'pc_factor_mat_solver_package': 'mumps'})

    # Shift the new state into the "previous step" holders.
    rho0.assign(rho)
    u0.assign(u)

    # Conservation diagnostics: order parameter and both mass measures.
    print("%s %s %s" % (assemble(order_parameter),
                        assemble(total_mass),
                        assemble(total_mass2)))

    # v, p, phi, mu are the components of u0, so this writes the state
    # that was just assigned.
    outfile.write(v, p, phi, mu, time=t)

