from firedrake import *
import numpy
import matplotlib.pyplot as plt

# -------------------------------------
# Mesh
# -------------------------------------

# Structured triangular mesh of the rectangle [0, lx] x [0, ly].
nbx, nby = 2, 2  # number of cells in the x and y directions
lx, ly = 0.5, 0.5  # side lengths of the rectangle
dim = 2  # spatial dimension, used by sigma() to build Identity(dim)

mesh = RectangleMesh(nbx, nby, lx, ly, quadrilateral=False)
# plt.show(plot(mesh))

# -------------------------------------
# Approximation Spaces
# -------------------------------------

# Mixed discretisation: P2 vector displacement (V), P1 pressure (W) and
# P1 temperature (X), combined into the mixed space Z.
V = VectorFunctionSpace(mesh, "Lagrange", 2)
W = FunctionSpace(mesh, "Lagrange", 1)
X = FunctionSpace(mesh, "Lagrange", 1)
Z = V * W * X

# Trial and test functions, in the order of the factors of Z.
(u, p, T) = TrialFunctions(Z)
(v, q, U) = TestFunctions(Z)

# Fields at the previous time level, one Function per sub-space.
u_prev, p_prev, T_prev = (Function(space) for space in (V, W, X))

# -------------------------------------
# Material Parameters
# -------------------------------------
# Mechanics
E = 6.E2  # Young's modulus
nu = 0.3  # Poisson's ratio
lmbda = Constant(E*nu/(1+nu)/(1-2*nu))  # first Lame parameter
mu = Constant(E/2/(1+nu))  # shear modulus (second Lame parameter)
rho_m = Constant(1000)  # bulk (mixture) density; split into solid/fluid parts by rho_s below

# Hydraulics
# NOTE(review): presumably the fluid bulk modulus (inverse of a 5e-3 compressibility) - confirm.
Compress = Constant(1./5.e-3)
Poro = Constant(0.18)  # porosity
c0 = Constant(Poro / Compress)  # pressure-storage coefficient (multiplies p in the mass balance)
b = Constant(1.)  # Biot coefficient (multiplies div(u) in the coupling terms)
Perm = Constant(4.e-2)  # permeability

Visc = Constant(0.1)  # fluid viscosity

Q = Constant(0)  # NOTE(review): not used by the weak forms of this script

# Thermics
Cp_h = Constant(4)  # fluid specific heat (per unit mass: multiplied by rho_h in C_sigm)
Cp_m = Constant(10)  # solid specific heat (per unit mass: multiplied by rho_s in C_sigm)
k_T = Constant(1.)  # thermal conductivity used in the energy-balance diffusion term
alpha_h = Constant(1.E-2)  # fluid thermal expansion coefficient
alpha_m = Constant(1.E-1)  # solid thermal expansion coefficient
h = Constant(0)  # liquid mass enthalpy
# Effective conductivity. NOTE(review): not used by the weak forms (k_T is).
lambda_T = Constant(1.6)
K = Constant(lmbda + 2 * mu / 3)  # drained bulk modulus of the skeleton, lambda + 2*mu/3
rho_h = Constant(8)  # fluid density
T0 = Constant(273)  # reference temperature

# Derived quantities
rho_s = Constant((rho_m - Poro*rho_h) / (1 - Poro))  # solid density from the mixture rule
# Volumetric heat capacity at constant stress (mixture rule over solid and fluid).
C_sigm = Constant((1 - Poro)*rho_s*Cp_m + Cp_h*rho_h*Poro)
# Volumetric heat capacity at constant strain (epsilon).
C_eps = Constant(C_sigm - 9.*T0*K*alpha_m*alpha_m)
k_h = Constant(Perm / Visc)  # hydraulic mobility: permeability / viscosity
c_T = Constant(alpha_h * Poro)  # alpha_h * porosity

# -------------------------------------
# Time Parameters
# -------------------------------------

# Uniform time step dt up to the final time Tfin.
dt = 1.
Tfin = 1.
# The first solve is for the end of the first step, so t starts at 0 + dt.
t = 0.0 + dt

# -------------------------------------
# Variational Formulation
# -------------------------------------


def sigma(v):
    """Return the linear-elastic stress 2*mu*sym(grad v) + lmbda*tr(grad v)*I."""
    strain = sym(grad(v))
    volumetric = lmbda*tr(grad(v))*Identity(dim)
    return 2.0*mu*strain + volumetric

def A(u, v):
    """Elasticity integrand: the stress of u contracted with grad(v)."""
    stress = sigma(u)
    return inner(stress, grad(v))

def B(v, q):
    """Coupling integrand between a vector field v and a scalar q: q * div(v)."""
    divergence = div(v)
    return q * divergence

def C(p, q):
    """Mass-type integrand: the plain product p * q."""
    product = p * q
    return product

def D(p, q):
    """Diffusion-type integrand: grad(q) . grad(p)."""
    grad_p = grad(p)
    grad_q = grad(q)
    return inner(grad_q, grad_p)

# Energy-balance coupling coefficients.
z = Constant(3*K*T0)
w = Constant(h*rho_h)

# Backward-Euler (implicit, step dt) weak form, one group of terms per test
# function: v -> momentum balance, q -> fluid mass balance, U -> energy balance.
# In each time-dependent balance the bilinear form carries the storage terms at
# the new time level, and the right-hand side carries the *same* storage terms
# (same operators, same signs) evaluated at the previous level.
a = (
    # momentum (v)
    A(u, v) - B(b*v, p)
    # fluid mass (q): -(b*div(u) + c0*p + c_T*T) storage + dt * Darcy diffusion
    - B(b*u, q) - C(c0*p, q) - C(c_T*T, q) + dt*D(k_h*p, q)
    # energy (U): storage terms + dt * (conduction k_T + enthalpy flux w*k_h*grad p)
    + B((z*alpha_m+w)*u, U) + C((w*c0-3*alpha_h*T0)*p, U) + C((C_eps-w*c_T)*T, U)
    + dt*D(k_T*T, U) + dt*D(w*k_h*p, U)
) * dx

# Right-hand side: the storage terms of `a` evaluated at the previous time level.
# The fluid-mass group must use exactly the operator and signs of `a`, i.e.
# -(b*div(u_prev) + c0*p_prev + c_T*T_prev); any sign mismatch means the
# discrete balance is no longer a time difference of one storage quantity (a
# non-zero steady state could not be reproduced).
L = (
    # fluid mass (q)
    - B(b*u_prev, q) - C(c0*p_prev, q) - C(c_T*T_prev, q)
    # energy (U)
    + B((z*alpha_m+w)*u_prev, U) + C((w*c0-3*alpha_h*T0)*p_prev, U)
    + C((C_eps - w*c_T)*T_prev, U)
) * dx

# -------------------------------------
# Time Dependent BC
# -------------------------------------

# Current time held in a Constant, so the boundary data follows tt.assign(...)
# in the time loop without rebuilding the BC objects.
tt = Constant(t)

# Z.sub(0) is the displacement space; its .sub(1) is the y-component.
displacement_space = Z.sub(0)
bcs = [
    # y-displacement 0.2*t on boundary id 4 (y == ly for RectangleMesh).
    DirichletBC(displacement_space.sub(1), 0.2*tt, (4,)),
    # Zero displacement on boundary id 3 (y == 0 for RectangleMesh).
    DirichletBC(displacement_space, Constant((0, 0)), (3,)),
]


# The coupled solution Function is created together with the solver further
# below; the trial function ``u`` used in the forms is deliberately not rebound.



print("Solving with Krylov solver and block preconditioner")

# PETSc options: outer FGMRES preconditioned by a multiplicative (block
# Gauss-Seidel) field split over the three mixed components of Z
# (0: displacement, 1: pressure, 2: temperature), each solved by inner FGMRES.
parameters = dict(
    ksp_type="fgmres",
    ksp_rtol=1.e-12,
    ksp_monitor=None,
    # mat_type="aij",     # the script runs fine if a MatAij is used
    pc_type="fieldsplit",
    pc_fieldsplit_type="multiplicative",
    # field 0: no preconditioning
    fieldsplit_0_ksp_type="fgmres",
    fieldsplit_0_pc_type="none",
    fieldsplit_0_ksp_monitor=None,
    # field 1: Jacobi preconditioning
    fieldsplit_1_ksp_type="fgmres",
    fieldsplit_1_ksp_monitor=None,
    fieldsplit_1_pc_type="jacobi",
    # NOTE(review): an "aux" PC is only consulted by some preconditioners; with
    # jacobi this option looks unused - confirm before relying on it.
    fieldsplit_1_aux_pc_type="none",
    # field 2: no preconditioning
    fieldsplit_2_ksp_type="fgmres",
    fieldsplit_2_pc_type="none",
    fieldsplit_2_ksp_monitor=None,
)

# =====================
# Post-processing, solver setup and time loop
# Piecewise-linear vector space used by output_data to project the displacement;
# assumed to share the DoF layout of mesh.coordinates (output_data adds the two
# vectors entrywise) - confirm if the mesh is ever made higher order.
V1 = VectorFunctionSpace(mesh, "Lagrange", degree=1)


def output_data(U, P, T):
    """Plot P and T on the mesh displaced by U, then restore the mesh.

    U is projected onto V1 and its DoF values are added to the mesh
    coordinates, so both plots show the deformed configuration. The reference
    coordinates are restored afterwards - also when plotting raises - so the
    mesh used by the solver is never left deformed.

    :arg U: displacement field (vector valued; projected onto V1).
    :arg P: pressure field to plot.
    :arg T: temperature field to plot.
    """
    coords = mesh.coordinates.vector()
    mesh_static = coords.get_local()  # snapshot of the reference coordinates
    # NOTE(review): the projected field is named "Velocity" although U is the displacement.
    displacement = project(U, V1, name="Velocity")
    coords.set_local(mesh_static + displacement.vector().get_local())
    try:
        plt.show(plot(P, contour=False))
        plt.show(plot(T, contour=False))
    finally:
        # Undo the deformation even if a plot call fails.
        coords.set_local(mesh_static)


# Coupled (u, p, T) solution at the current time level. Named ``sol`` so it
# does not shadow the temperature FunctionSpace ``X`` defined with the spaces.
sol = Function(Z)
thm_problem = LinearVariationalProblem(a, L, sol, bcs=bcs)
thm_solver = LinearVariationalSolver(thm_problem, solver_parameters=parameters)

# Relative tolerance so accumulated round-off in ``t += dt`` cannot skip the
# final step when Tfin is (nominally) a multiple of dt.
t_tol = 1e-10 * dt

while t <= Tfin + t_tol:
    # ``t`` is the plain-float mirror of the Constant ``tt`` used in the BCs;
    # formatting the float avoids relying on float() of a UFL Constant.
    print("Time step %f" % t)

    thm_solver.solve()

    # Copy the new solution into the previous-level fields used by L.
    u_, p_, T_ = sol.split()
    u_prev.assign(u_)
    p_prev.assign(p_)
    T_prev.assign(T_)

    # Advance time; the Dirichlet data follows through tt.
    t += dt
    tt.assign(t)

    # output_data(u_prev, p_prev, T_prev)
    print(u_.vector().array())
    print(p_.vector().array())
    print(T_.vector().array())
