# IMPORT LIBRARIES

from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import ufl, os

import matplotlib.pyplot as plt

# DIFFERENTIAL OPERATORS

def gradperp(i):
    """Return the perpendicular gradient (-di/dy, di/dx) of the scalar field ``i``.

    In this script it maps a streamfunction psi to its velocity field
    u = (-psi_y, psi_x).
    """
    return as_vector((-i.dx(1), i.dx(0)))

def steadylinearQG2d(Ztrial, Ztest, Winds, bc):
    """Solve the steady, linear two-layer QG problem with a direct LU solve.

    The mixed unknown is ordered (psi1, u1, q1, psi2, u2, q2): streamfunction,
    velocity and potential vorticity of the upper (1) and lower (2) layers.

    Parameters
    ----------
    Ztrial, Ztest : TrialFunction / TestFunction on the same mixed space.
        The returned Function lives on ``Ztrial.function_space()``.
    Winds : UFL expression or Function
        Forcing; enters the weak form as ``-Winds*phi1*dx``.
    bc : DirichletBC or list of DirichletBC
        Boundary conditions applied to the linear variational problem.

    Returns
    -------
    Function
        The mixed solution, named "Linear Solution".

    Reads the module-level parameters beta, F1, F2, nu1, nu2 and r.
    """
    # Split into per-layer components: streamfunction, velocity, PV.
    (psi1, u1, q1, psi2, u2, q2) = split(Ztrial)
    (phi1, v1, p1, phi2, v2, p2) = split(Ztest)

    # Allocate the solution on the trial function's own space instead of the
    # module-level Z2, so the solver follows whatever space the caller used.
    soln = Function(Ztrial.function_space(), name="Linear Solution")

    # Weak form, row by row:
    #   phi1/phi2 rows: PV inversion q = lap(psi) -/+ F*(psi1 - psi2)
    #   v1/v2 rows:     u = gradperp(psi)
    #   p1/p2 rows:     beta*psi_x (+ r*q2 bottom drag) + viscous terms
    # NOTE(review): Winds is tested against phi1 (the inversion row) rather
    # than p1 (the vorticity row), and with nu1 = 0 the p1 row has no q1
    # term -- confirm this placement of the forcing is intended.
    G =   ( phi1*q1 + inner(grad(phi1), grad(psi1)) + F1*phi1*(psi1-psi2) )*dx \
        + ( phi2*q2 + inner(grad(phi2), grad(psi2)) - F2*phi2*(psi1-psi2) )*dx \
        + ( inner(v1,u1) - inner(v1,gradperp(psi1)) )*dx \
        + ( inner(v2,u2) - inner(v2,gradperp(psi2)) )*dx \
        + ( beta*p1*psi1.dx(0) + nu1*inner(grad(p1),grad(q1)) )*dx \
        + ( beta*p2*psi2.dx(0) + r*p2*q2 + nu2*inner(grad(p2),grad(q2)) )*dx \
        - ( Winds*phi1 )*dx

    # lhs(G) collects the trial-dependent (bilinear) terms and rhs(G) the
    # remaining linear terms with flipped sign. The solver assembles both
    # itself, so no separate assemble() calls are needed here.
    linear_problem = LinearVariationalProblem(lhs(G), rhs(G), soln, bcs=bc)
    linear_solver = LinearVariationalSolver(
        linear_problem,
        solver_parameters={
            'ksp_type': 'preonly',   # a single application of the direct solver
            'ksp_monitor': True,
            # Presumably a legacy option; 'mat_type': 'aij' already requests
            # a monolithic (non-nested) matrix -- verify before removing.
            'matnest': False,
            'mat_type': 'aij',
            'pc_type': 'lu',
        })

    linear_solver.solve()

    return soln

# GEOMETRY

n0   = 50            # number of cells along each side of the rectangle
Ly   = 1.0
Lx   = np.sqrt(2)
mesh = RectangleMesh(n0, n0, Lx, Ly, reorder=None)


# FUNCTION AND VECTOR SPACES

V1 = FunctionSpace(mesh, "CG", 3)  # streamfunction
V2 = FunctionSpace(mesh, "BDM", 2) # velocity
V3 = FunctionSpace(mesh, "DG", 1)  # potential vorticity
# Two-layer mixed space: (psi1, u1, q1, psi2, u2, q2).
Z2 = V1 * V2 * V3 * V1 * V2 * V3
#Z2 = V1 * V2 * V3   # single-layer alternative

# TEST/TRIAL FUNCTIONS
Ztrial = TrialFunction(Z2)
Ztest  = TestFunction(Z2)

# BOUNDARY CONDITIONS

# Homogeneous Dirichlet conditions on both streamfunctions (sub-spaces 0 and 3).
bc1 = DirichletBC(Z2.sub(0), 0.0, "on_boundary")
bc2 = DirichletBC(Z2.sub(3), 0.0, "on_boundary")
bcs = [bc1, bc2]

# MODEL PARAMETERS

#alpha   = Constant(10.0)
beta    = Constant(1.0)
F1      = Constant(1.0)
F2      = Constant(1.0)
nu1     = Constant(0.0)
nu2     = Constant(0.0)
r       = Constant(0.05)
tau     = Constant(0.001)

# Wind forcing -tau*cos(pi*(y - 0.5)), built as a UFL expression on the mesh
# coordinates instead of the deprecated string-based Expression class.
xy      = SpatialCoordinate(mesh)
Gwinds  = Function(V1).interpolate(-tau*cos(np.pi*(xy[1] - 0.5)))

#plot(Gwinds)
#plt.show()

soln0 = steadylinearQG2d(Ztrial, Ztest, Gwinds, bcs)

# POST-PROCESSING (disabled: kept inside a string literal so it does not run)
"""
psi1_0, u1_0, q1_0, psi2_0, u2_0, q2_0 = soln0.split()

print('upper limits: [%g, %g]' % (psi1_0.dat.data.min(), psi1_0.dat.data.max()))

# PLOT SOLUTION
plot(psi1_0)
plt.xlabel('Zonal')
plt.ylabel('Meridional')
plt.title('Linear Stommel Solution - Upper')
plt.xlim([0,Lx])
plt.ylim([0,Ly])
plt.show()

print('lower limits: [%g, %g]' % (psi2_0.dat.data.min(), psi2_0.dat.data.max()))

plot(psi2_0)
plt.xlabel('Zonal')
plt.ylabel('Meridional')
plt.title('Linear Stommel Solution - Lower')
plt.xlim([0,Lx])
plt.ylim([0,Ly])
plt.show()
"""

