# IMPORT LIBRARIES

from firedrake import *
from firedrake.petsc import PETSc
import numpy as np
import ufl, os

import matplotlib.pyplot as plt

# DIFFERENTIAL OPERATORS

def gradperp(i):
    """Return the perpendicular gradient of the scalar field ``i``.

    Builds the UFL vector (-d(i)/dy, d(i)/dx).  Applied to a streamfunction
    psi this gives the velocity (-psi_y, psi_x) used in the weak form of
    ``steadylinearQG2d``.
    """
    return as_vector((-i.dx(1), i.dx(0)))

def steadylinearQG2d(ZTrial, ZTest, Winds, bcs):
    """Solve the steady linear single-layer QG problem on a mixed space.

    The mixed unknown is (psi1, u1, q1): streamfunction, velocity and
    potential vorticity.  The weak form below weakly imposes

    * q1 = lap(psi1) - F1*psi1           (PV inversion, integrated by parts),
    * u1 = gradperp(psi1)                (L2 projection of the velocity),
    * beta*psi1_x - nu*lap(q1) + r*q1 = Winds  (steady PV balance).

    Parameters
    ----------
    ZTrial : trial function on the mixed space (psi1, u1, q1).
    ZTest : test function on the same mixed space; must be created with
        ``TestFunction``, not ``TrialFunction``.
    Winds : forcing field multiplying the PV test function ``p1``.
    bcs : boundary conditions passed to the variational problem.

    Returns
    -------
    Function on ``ZTrial``'s function space holding the solution.

    Notes
    -----
    The coefficients ``F1``, ``beta``, ``nu`` and ``r`` are read from module
    globals at call time, so they must be defined before calling.
    """
    psi1, u1, q1 = split(ZTrial)
    phi1, v1, p1 = split(ZTest)

    # Solve on the trial function's own mixed space instead of the
    # module-level ``Z``, so the function does not depend on that global.
    soln = Function(ZTrial.function_space(), name="Linear Solution")

    # Define weak form.
    # NOTE(review): if q1 is discontinuous (e.g. DG), the cell-wise
    # nu*grad(p1).grad(q1) term is not a consistent -nu*lap(q1) without
    # interior-facet terms; this is harmless while nu == 0.
    G = (phi1*q1 + inner(grad(phi1), grad(psi1)) + F1*phi1*psi1)*dx \
        + (inner(v1, u1) - inner(v1, gradperp(psi1)))*dx \
        + (beta*p1*psi1.dx(0) + nu*inner(grad(p1), grad(q1)) + r*p1*q1)*dx \
        - (Winds*p1)*dx

    # Split G into its bilinear (lhs) and linear (rhs) parts and solve the
    # monolithic system with a direct LU factorisation.
    linear_problem = LinearVariationalProblem(lhs(G), rhs(G), soln, bcs=bcs)
    linear_solver = LinearVariationalSolver(
        linear_problem,
        solver_parameters={
            'ksp_type': 'preonly',
            'ksp_monitor': True,
            'matnest': False,
            'mat_type': 'aij',
            'pc_type': 'lu',
        })

    linear_solver.solve()

    return soln

# GEOMETRY

n0   = 50            # number of cells along each side of the rectangle
Ly   = 1.0           # meridional (y) extent
Lx   = np.sqrt(2)    # zonal (x) extent
mesh = RectangleMesh(n0, n0, Lx, Ly, reorder=None)

# FUNCTION AND VECTOR SPACES

V1 = FunctionSpace(mesh, "CG", 3)  # streamfunction
V2 = FunctionSpace(mesh, "BDM", 2) # velocity
V3 = FunctionSpace(mesh, "DG", 1)  # potential vorticity
Z = V1 * V2 * V3

# TEST/TRIAL FUNCTIONS
# The test function must be created with TestFunction.  Creating it with
# TrialFunction makes the weak form quadratic in the trial argument instead
# of bilinear, so lhs()/rhs() cannot split it.
ZTrial = TrialFunction(Z)
ZTest = TestFunction(Z)

# BOUNDARY CONDITIONS
# Homogeneous Dirichlet condition on the streamfunction (sub-space 0).

bc1 = DirichletBC(Z.sub(0), 0.0, "on_boundary")
bcs = [bc1]

# MODEL PARAMETERS
# Numeric literals (not strings) are passed to Constant.

beta    = Constant(1.0)    # coefficient of the psi1.dx(0) term in the PV balance
F1      = Constant(1.0)    # coefficient of psi1 in the PV inversion
F2      = Constant(1.0)    # NOTE(review): not used by this single-layer script
nu      = Constant(0.0)    # coefficient of the grad(p1).grad(q1) term
r       = Constant(0.05)   # coefficient of the linear q1 damping term
tau     = Constant(0.001)  # amplitude of the wind forcing below
Gwinds  = Function(V1).interpolate(Expression("-tau*cos(pi*(x[1]-0.5))", tau=tau))

# plot(Gwinds)
# plt.show()

# GET SOLUTION

soln0 = steadylinearQG2d(ZTrial, ZTest, Gwinds, bcs)

# Unpack the mixed solution into streamfunction, velocity and PV components.
psi1_0, u1_0, q1_0 = soln0.split()

# Single-argument print(...) runs on both Python 2 and Python 3.  The range
# is reported as [min, max], matching interval notation.
print('upper limits: [%g, %g]' % (psi1_0.dat.data.min(), psi1_0.dat.data.max()))

# PLOT SOLUTION
plot(psi1_0)
plt.xlabel('Zonal')
plt.ylabel('Meridional')
plt.title('Linear Stommel Solution - Upper')
plt.xlim([0, Lx])
plt.ylim([0, Ly])
plt.show()

# NOTE: steadylinearQG2d solves for a single layer (psi1, u1, q1) only, so
# there is no lower-layer streamfunction ``psi2_0`` to report or plot.  A
# lower-layer report/plot ('Linear Stommel Solution - Lower') requires a
# two-layer solver (presumably using F2) and can be added here once one exists.
