from __future__ import division
from __future__ import print_function
from firedrake import *
import numpy as np

class Configuration(object):
    """Base class for a closed set of named settings.

    Subclasses declare every allowed setting as a class attribute holding its
    default value.  Keyword arguments given to the constructor override those
    defaults on the instance.  Assigning any name the object does not already
    have (neither on the instance nor on the class) raises ``AttributeError``,
    so misspelled setting names fail loudly instead of being silently stored.
    """

    def __init__(self, **kwargs):
        # ``items()`` works on both Python 2 and Python 3; the previous
        # ``iteritems()`` exists only on Python 2.
        for name, value in kwargs.items():
            # Routed through the overridden __setattr__ so unknown names are
            # rejected here as well.
            setattr(self, name, value)

    def __setattr__(self, name, value):
        """Cause setting an unknown attribute to be an error"""
        if not hasattr(self, name):
            raise AttributeError("'%s' object has no attribute '%s'" % (type(self).__name__, name))
        object.__setattr__(self, name, value)
		
class MyParameters(Configuration):
    """Default model parameters for this script.

    Each value is a class attribute, so ``MyParameters(name=value)`` may
    override any of them, while names not listed here are rejected by
    ``Configuration.__setattr__``.  Within this file only ``Re``, ``rho1`` and
    ``rho2`` appear to be read; the others look reserved for model variants
    not exercised here -- confirm before removing any.
    """

    # Time stepping
    dt = 1.0e-02        # time step

    # Dimensionless groups
    Cn = 0.03           # Cahn number
    We = 1              # Weber number
    Pe = 20000          # Peclet number
    Re = 1              # Reynolds number
    Fr2 = 0.1           # Froude number squared

    # Switches and stabilisation
    gravity = False     # switch for gravity
    beta = 2            # stabilising term

    # Densities (presumably one per phase), stored as Firedrake Constants, and
    # UFL expressions derived from them.  The derived values must come after
    # rho1/rho2 because the class body evaluates top to bottom.
    rho1 = Constant(1.0)
    rho2 = Constant(1.0)
    drho_hat = 0.5 * (1 - rho2 / rho1)
    alpha = (rho2 - rho1) / (rho2 + rho1)

class phiBC(DirichletBC):
    """DirichletBC whose constrained node set is every dof of ``V``.

    The arguments are forwarded unchanged to ``DirichletBC``.  Afterwards the
    ``nodes`` attribute is overwritten with ``0 .. V.dof_dset.total_size - 1``
    (as int32), so the condition is presumably meant to act on the whole field
    rather than only on sub-domain ``sid`` -- confirm against callers.  Not
    used elsewhere in this script.
    """

    def __init__(self, V, value, sid):
        super(phiBC, self).__init__(V, value, sid)
        dof_count = V.dof_dset.total_size
        self.nodes = np.arange(dof_count, dtype=np.int32)

# Single parameter set for the whole script (class defaults, no overrides).
myparameters = MyParameters()

# Create mesh and define function spaces
m=64  # number of cells along each side of the unit square
mesh = UnitSquareMesh(m, m)
P1 = FunctionSpace(mesh, "Lagrange", 1)  # scalar P1 space (pressure component of W)
P2 = VectorFunctionSpace(mesh, "Lagrange", 2)  # vector P2 space (velocity component of W)
n = FacetNormal(mesh)  # facet normal, used in the boundary integral of F
# Create mixed space: (velocity, pressure) = P2 x P1, i.e. a Taylor-Hood pair
W = P2*P1

# Define functions
u   = Function(W)  # current (velocity, pressure) solution; the unknown passed to solve()
u0  = Function(W)	# lagged solution

# Dirichlet boundary condition for velocity
# NOTE(review): the boundary ids 1-4 below assume Firedrake's UnitSquareMesh
# numbering (1: x == 0, 2: x == 1, 3: y == 0, 4: y == 1) -- confirm for the
# installed Firedrake version.
noslip = Constant((0, 0))
# Prescribed velocity: x-component 0, y-component 0.25*(x^2 - 1), which is
# -0.25 at x = 0 and 0 at x = 1 (so it matches the no-slip value on id 2).
inlet = Function(P2).interpolate(Expression(("0.0","0.25*(x[0] + 1)*(x[0] - 1)")))
bc0 = DirichletBC(W.sub(0), inlet, 4)  # whole velocity = inlet profile on id 4
bc1 = DirichletBC(W.sub(0), noslip, 2)	# no-slip (zero velocity) on id 2
# Only the x-velocity is fixed (to 0) on id 1; the y-velocity is left free
# there (presumably a symmetry line -- confirm).
bc2 = DirichletBC(W.sub(0).sub(0), 0, 1)

# Initial conditions
# v0/p0 are the velocity/pressure parts of u0; interpolating into v0 is
# presumably meant to write into u0's velocity data (Firedrake split()
# sub-functions) -- confirm for the installed version.
v0, p0 = u0.split()
vel_ini = Expression(("0.0", "0.0"))	# Initial conditions for velocity
v0.interpolate(vel_ini)


def F(u, u0, rho, rho0, rho_hat, W,n, myparameters):
    """Build the residual form of the steady momentum/continuity system.

    The returned form is the sum of:

    * a momentum residual tested with the velocity test function ``chi``:
      convection ``(v0 . grad) v``, pressure ``-p div(chi)``, viscous terms
      ``(1/Re)(grad v + grad v^T) : grad chi - 2/(3 Re) div v div chi``, minus
      the facet integral ``(1/Re) chi_i (nabla_grad v)_ij n_j`` over the facets
      tagged 1, 3 and 4;
    * a continuity residual ``theta div(v)`` tested with the pressure test
      function ``theta``.

    There is no time-derivative term, so the form describes a steady problem.

    :arg u: mixed (velocity, pressure) Function in ``W``; the unknown.
    :arg u0: mixed Function whose velocity advects ``v`` in the convective
        term.  Passing ``u`` itself (as this script does) makes that term
        nonlinear in ``u``.
    :arg rho: unused by this form (kept for interface compatibility).
    :arg rho0: unused by this form (kept for interface compatibility).
    :arg rho_hat: unused by this form (kept for interface compatibility).
    :arg W: mixed function space providing the test functions.
    :arg n: facet normal used in the boundary integral.
    :arg myparameters: parameter object; only ``Re`` is read.
    :returns: the UFL residual form, to be used as ``F(...) == 0``.
    """
    # access parameters
    Re = myparameters.Re

    # Test functions: chi for the momentum equation, theta for continuity.
    chi, theta = TestFunctions(W)

    v, p = split(u)
    v0 = split(u0)[0]  # only the lagged velocity is needed

    # Float literals keep these coefficients correct even without
    # ``from __future__ import division`` (2/(3*1) would be 0 in Python 2).
    inv_re = 1.0/Re
    dil_coeff = 2.0/(3.0*Re)

    # ``i`` and ``j`` are not defined in this file; they are presumably UFL's
    # predefined free indices from ``from firedrake import *`` (repeated
    # indices are summed) -- confirm.
    convection = inner(dot(v0, nabla_grad(v)), chi)*dx
    pressure = (-p*div(chi))*dx
    viscous = inv_re*inner(nabla_grad(v), nabla_grad(chi))*dx
    viscous_transpose = inv_re*inner(nabla_grad(v)[j, i], nabla_grad(chi)[i, j])*dx
    dilatation = dil_coeff*div(v)*div(chi)*dx
    boundary = (inv_re*inner(inner(chi[i], nabla_grad(v)[i, j]), n[j])
                * (ds(1) + ds(4) + ds(3)))

    # Summed in the same order as the original expression.
    f3 = convection + pressure + viscous + viscous_transpose - dilatation - boundary

    # Continuity equation.
    f4 = theta*div(v)*dx

    residual = f3 + f4
    return residual
	

# Label the velocity/pressure parts of u0 so they are named in the output.
v, p = u0.split()
v.rename("Velocity")
p.rename("Pressure")

outfile = File("./Results/nsch_res.pvd")

parameters["matnest"] = False

# Reference velocity: the prescribed inlet profile interpolated into P2.
exact = Function(P2).interpolate(inlet)
# L2 error of the initial velocity (v0) against the reference profile.
# print() as a function (via __future__.print_function) parses on Python 2
# and 3 and prints the same text as the old print statement.
print("L2norm = ", sqrt(assemble(inner(v0-exact, v0-exact)*dx)))

rho1 = myparameters.rho1
rho2 = myparameters.rho2  # NOTE(review): not used below
# u is passed both as the unknown and as the lagged solution, so the
# convective term is nonlinear in u.  The rho arguments are ignored by F.
F_u = F(u, u, rho1, rho1, rho1, W, n, myparameters)
t = 0.0
# Solve the system of equations with a direct LU (MUMPS) solve.
# NOTE(review): 'pc_factor_mat_solver_package' is the pre-PETSc-3.9 option
# name (later 'pc_factor_mat_solver_type') -- confirm for the installed PETSc.
solve(F_u == 0, u, bcs=[bc0, bc1, bc2],
      solver_parameters={'pc_type': 'lu', 'pc_factor_mat_solver_package': 'mumps'})
u0.assign(u)
v, p = u0.split()
outfile.write(v, p, time=t)

# L2 error of the computed velocity against the reference profile.
print("L2norm = ", sqrt(assemble(inner(v-exact, v-exact)*dx)))

