from firedrake import *
import numpy as np
import matplotlib.pyplot as plt

# Geometric parameters
eps = 1.0 / 3.0        # scaling parameter; enters the energy densities as powers of eps
Lx = 3 * eps           # right end of the interval [0, Lx] on which the mesh is built
t = Constant(1.e-2)    # presumably the thickness (enters D as t**3) -- TODO confirm
G0 = Constant(1.0)     # NOTE(review): not referenced anywhere in this script

# Reference values; only k0 is used below (it is subtracted from k_).
v0 = Constant(0.0)
k0 = Constant(1.0)
f0 = Constant(0.0)

# Constitutive parameters
Y = Constant(1.0)      # presumably Young's modulus -- TODO confirm
nu = Constant(0.0)     # presumably Poisson's ratio -- TODO confirm
# Stiffness coefficient multiplying the psi_b energy term: Y t^3 / (12 (1 - nu^2))
D = (Y * t**3) / (12.0 * (1 - nu**2))

# Mesh: n uniform cells on the interval [0, Lx]
n = 100  # number of mesh cells
mesh = IntervalMesh(n, 0., Lx)


# Discrete space: one cubic Hermite space per unknown, combined into a
# mixed space so the three fields are solved for together.
V = FunctionSpace(mesh, "Hermite", 3)
Qh = MixedFunctionSpace([V, V, V])

q_ = Function(Qh)        # current state (the unknown of the nonlinear problem)
q = TrialFunction(Qh)    # trial function on the mixed space
q_t = TestFunction(Qh)   # test function on the mixed space

# Component views, in the order (v, k, f) of the mixed space
v_, k_, f_ = split(q_)
v, k, f = split(q)
v_t, k_t, f_t = split(q_t)


# Energy densities
# Deviation of k_ from the reference value k0
delk_ = k_ - k0

# Shorthand for the x-derivatives that appear in the densities below
dv_ = v_.dx(0)
ddv_ = v_.dx(0).dx(0)
dk_ = k_.dx(0)
ddk_ = k_.dx(0).dx(0)
df_ = f_.dx(0)
ddf_ = f_.dx(0).dx(0)
ddelk_ = delk_.dx(0)
dkf_ = (k_*f_).dx(0)

# psi_b: quadratic in delk_, v_'' and k_'', with terms ordered by power of eps
psi_b = 0.5*(
    (1/eps**2)*(delk_**2)
    + (1/eps)*(2*nu*delk_*ddv_)
    + ((1./6.)*(1. - nu)*ddelk_**2 + ddv_**2)
    + (eps**2)*ddk_**2/720.
)

# psi_m: quadratic in f_ and its first and second derivatives
psi_m = 0.5*(
    (1/eps**4)*(720*f_**2)
    + (1/eps**2)*(240./7.)*((1. + nu)*df_**2 + nu*f_*ddf_)
    + (10./7.)*ddf_**2
)

# psi_c: terms coupling v_, k_ and f_
psi_c = (
    (1/eps)*(dv_*dkf_)
    + (1./84.)*dk_*dkf_
    + (3./2.)*(k_**2 - k0**2)*ddf_
)

#  boundary conditions
class lHermiteV(DirichletBC):
    """Dirichlet condition that constrains only node 0 of the space.

    The ``nodes`` property is overridden so the condition acts on one
    hand-picked node instead of nodes derived from the sub-domain argument.
    """

    @utils.cached_property
    def nodes(self):
        """Return the constrained node indices, here ``[0]``.

        Raises:
            NotImplementedError: if the mesh communicator of ``V`` has more
                than one rank (the hard-coded index is only used in serial).
        """
        n_ranks = V.mesh().mpi_comm().size
        if n_ranks > 1:
            raise NotImplementedError
        # NOTE(review): presumably the value DOF at the left endpoint (the
        # original comment said "derivative", the same as lHermiteDV, which
        # uses node 1) -- confirm against the Hermite DOF numbering.
        return [0]

class rHermiteV(DirichletBC):
    """Dirichlet condition that constrains only node ``2*n_cells + 1``.

    The ``nodes`` property is overridden so the condition acts on one
    hand-picked node instead of nodes derived from the sub-domain argument.
    """

    @utils.cached_property
    def nodes(self):
        """Return the constrained node indices, here ``[2*n_cells + 1]``.

        Raises:
            NotImplementedError: if the mesh communicator of ``V`` has more
                than one rank (the hard-coded index is only used in serial).
        """
        if V.mesh().mpi_comm().size > 1:
            raise NotImplementedError
        # Take the cell count from the mesh itself rather than from a
        # module-level ``n``, which is not defined in this script.
        n_cells = V.mesh().num_cells()
        # Assuming two DOFs per vertex, the space has 2*n_cells + 2 nodes and
        # this is the last one.
        # NOTE(review): the left-end classes pair V with node 0 and DV with
        # node 1; here V gets the odd index and rHermiteDV the even one.
        # Confirm which of the two is the value and which the derivative DOF.
        return [2*n_cells + 1]

class lHermiteDV(DirichletBC):
    """Dirichlet condition that constrains only node 1 of the space.

    The ``nodes`` property is overridden so the condition acts on one
    hand-picked node instead of nodes derived from the sub-domain argument.
    """

    @utils.cached_property
    def nodes(self):
        """Return the constrained node indices, here ``[1]``.

        Raises:
            NotImplementedError: if the mesh communicator of ``V`` has more
                than one rank (the hard-coded index is only used in serial).
        """
        n_ranks = V.mesh().mpi_comm().size
        if n_ranks > 1:
            raise NotImplementedError
        # The derivative component of the left endpoint.
        return [1]

class rHermiteDV(DirichletBC):
    """Dirichlet condition that constrains only node ``2*n_cells``.

    The ``nodes`` property is overridden so the condition acts on one
    hand-picked node instead of nodes derived from the sub-domain argument.
    """

    @utils.cached_property
    def nodes(self):
        """Return the constrained node indices, here ``[2*n_cells]``.

        Raises:
            NotImplementedError: if the mesh communicator of ``V`` has more
                than one rank (the hard-coded index is only used in serial).
        """
        if V.mesh().mpi_comm().size > 1:
            raise NotImplementedError
        # Take the cell count from the mesh itself rather than from a
        # module-level ``n``, which is not defined in this script.
        n_cells = V.mesh().num_cells()
        # Assuming two DOFs per vertex, the space has 2*n_cells + 2 nodes and
        # this is the second-to-last one.
        # NOTE(review): the left-end classes pair V with node 0 and DV with
        # node 1; here DV gets the even index and rHermiteV the odd one.
        # Confirm which of the two is the value and which the derivative DOF.
        return [2*n_cells]

# The unknown q_ lives in the mixed space Qh, so each condition is attached to
# the component of Qh it constrains (order v, k, f as in split(q_)), not to the
# standalone space V. The "on_boundary" argument is only a placeholder: every
# class below overrides ``nodes`` with an explicit node index.
Qv = Qh.sub(0)
Qk = Qh.sub(1)
Qf = Qh.sub(2)

# v: left end (1, 1d) and right end (2, 2d)
bcv1 = lHermiteV(Qv, 0, "on_boundary")
bcv1d = lHermiteDV(Qv, 0, "on_boundary")
bcv2 = rHermiteV(Qv, 0, "on_boundary")
bcv2d = rHermiteDV(Qv, 0, "on_boundary")

# k: left end (1, 1d) and right end (2, 2d).
# bck2/bck2d previously repeated the left-end classes; they now use the
# right-end ones, following the bcv2/bcf2 naming pattern.
bck1 = lHermiteV(Qk, 0, "on_boundary")
bck1d = lHermiteDV(Qk, 0, "on_boundary")
bck2 = rHermiteV(Qk, 0, "on_boundary")
bck2d = rHermiteDV(Qk, 0, "on_boundary")

# f: left end (1, 1d) and right end (2, 2d)
bcf1 = lHermiteV(Qf, 0, "on_boundary")
bcf1d = lHermiteDV(Qf, 0, "on_boundary")
bcf2 = rHermiteV(Qf, 0, "on_boundary")
bcf2d = rHermiteDV(Qf, 0, "on_boundary")

# Active set: v and k constrained at the left end only, f at both ends.
bcs = [bcv1, bcv1d, bck1, bck1d, bcf1, bcf1d, bcf2, bcf2d]
# Alternative with every field constrained at both ends:
#bcs = [bcv1, bcv1d, bcv2, bcv2d, bck1, bck1d, bck2, bck2d, bcf1, bcf1d, bcf2, bcf2d]

# Problem
# Energy functional assembled from the three densities psi_b, psi_m, psi_c.
L = D*psi_b*dx - 1./(t*Y)*psi_m*dx + psi_c*dx
# Residual: first variation of L at q_ in the direction of the test function.
F = derivative(L, q_, q_t)
# Jacobian: variation of the residual in the direction of the trial function
# q. Differentiating along the test function q_t again would not give a
# bilinear form in (trial, test).
dF = derivative(F, q_, q)

# Solve the nonlinear problem F(q_) = 0 for q_, subject to bcs, with the
# explicit Jacobian.
solve(F == 0, q_, bcs=bcs, J=dF)


