
'''

Test of the quasi-geostrophic (QG) equation in Firedrake.

'''

# Import


from __future__ import division # Get proper divison

import itertools
import numpy as np
import matplotlib.pyplot as plot
import math
import random 
from scipy import integrate as int1
from scipy.stats import gaussian_kde
from scipy import stats
from mpl_toolkits.mplot3d import Axes3D
import pickle
from scipy.stats import norm
import time
import scipy.sparse
from pulp import *
import scipy.sparse as sp
from matplotlib import cm

from firedrake import *
# Keep PyOP2 logging quiet (warnings and above only).
parameters["pyop2_options"]["log_level"] = "WARNING"

n0 = 5          # number of cells along each side of the unit square
dx0 = 1.0/n0    # nominal cell width; used below to choose the time step

# n0 x n0 mesh of the unit square [0, 1] x [0, 1]. With quadrilateral=False
# each square cell is split into triangles.
# NOTE(review): an earlier comment called this a quadrilateral mesh
# "suitable for make_c_evaluator", which contradicts quadrilateral=False —
# confirm which cell type is intended.
mesh = RectangleMesh(n0, n0, 1, 1, quadrilateral=False)

# Function spaces:
#   Vdg - piecewise-linear discontinuous scalars (used for the PV fields)
#   Vcg - piecewise-linear continuous scalars (used for the streamfunction)
#   Vu  - piecewise-linear discontinuous vectors (used for velocity output)
Vdg = FunctionSpace(mesh, "DG", 1)
Vcg = FunctionSpace(mesh, "CG", 1)
Vu = VectorFunctionSpace(mesh, "DG", 1)

# Three sine modes on the unit square, each vanishing on the boundary.
# They are not referenced elsewhere in this script. Presumably they are
# basis functions for a stochastic streamfunction perturbation — TODO confirm.
nxi = 3
Xis = [Function(Vcg) for _ in range(nxi)]

Xis[0].interpolate(Expression("sin(pi*x[0])*sin(pi*x[1])"))
Xis[1].interpolate(Expression("sin(2*pi*x[0])*sin(pi*x[1])"))
Xis[2].interpolate(Expression("sin(pi*x[0])*sin(pi*2*x[1])"))

# Initial relative PV: a q = +1 disc of radius 0.1 centred at (0.5, 0.25),
# a q = -1 disc of radius 0.1 centred at (0.5, 0.75), and zero elsewhere.
_q0_expr = ("pow(x[0]-0.5,2)+pow(x[1]-0.25,2)< 0.01 ? 1.0 : "
            "(pow(x[0]-0.5,2)+pow(x[1]-0.75,2) < 0.01 ? -1.0 : 0.0)")
q0 = Function(Vdg).interpolate(Expression(_q0_expr))

# DG work fields. dq1 receives each advection update and q1 holds the PV of
# the current stage. qh is not used in the code shown.
dq1, qh, q1 = (Function(Vdg) for _ in range(3))

# CG streamfunction fields. psi1 is not used in the code shown.
psi0, psi1 = Function(Vcg), Function(Vcg)

# Model coefficients.
F = Constant(1.0)       # rotational Froude number
beta = Constant(0.1)    # beta-plane coefficient

# Time step. Assuming |u| ~ 1, the Courant number is Dt/dx0 = 0.1, which is
# within the target bound of 0.3 (i.e. Dt < 0.3*dx0).
Dt = 0.1*dx0
dt = Constant(Dt)

# PV inversion: given the relative PV q1, solve for the streamfunction psi0
# from the weak form
#   F*(psi, phi) + (grad psi, grad phi) - beta*(d psi/dy, phi) = -(q1, phi)
# with psi = 0 on the whole boundary.
# NOTE(review): the beta term appears here in the inversion operator rather
# than as an advection term; confirm this is the intended model.

psi = TrialFunction(Vcg)
phi = TestFunction(Vcg)

Apsi = (F*psi*phi + inner(grad(psi), grad(phi)) - beta*phi*psi.dx(1))*dx
Lpsi = -q1*phi*dx

# Homogeneous Dirichlet conditions on boundary markers 1-4. For Firedrake's
# RectangleMesh these are the sides x=0, x=1, y=0 and y=1.
bc1 = [DirichletBC(Vcg, 0., marker) for marker in (1, 2, 3, 4)]

# The -beta*phi*psi.dx(1) term is skew-symmetric under these zero Dirichlet
# conditions, so the assembled operator is NOT symmetric. CG is only valid
# for symmetric positive-definite systems, so use GMRES.
psi_problem = LinearVariationalProblem(Apsi, Lpsi, psi0, bcs=bc1)
psi_solver = LinearVariationalSolver(psi_problem,
                                     solver_parameters={
                                         'ksp_type': 'gmres',
                                         'pc_type': 'sor',
                                     })

def gradperp(u):
    """Return the perpendicular gradient (-du/dy, du/dx) of a scalar field."""
    return as_vector((-u.dx(1), u.dx(0)))


# --- DG upwind discretisation of PV advection by u = gradperp(psi0) ---

n = FacetNormal(mesh)

# Outward normal velocity clipped at zero: (u.n + |u.n|)/2. It is non-zero
# only on the outflow side of a facet, which selects the upwind value of q.
u_dot_n = dot(gradperp(psi0), n)
un = 0.5*(u_dot_n + abs(u_dot_n))

q = TrialFunction(Vdg)
p = TestFunction(Vdg)

a_mass = p*q*dx
# Cell-interior (volume) part of the weak advection operator.
a_int = dot(grad(p), -gradperp(psi0)*q)*dx
# Upwind flux across interior facets. No exterior-facet (ds) term is
# included. Presumably this is because psi0 = 0 on the boundary makes u.n
# vanish there — confirm.
a_flux = dot(jump(p), un('+')*q('+') - un('-')*q('-'))*dS
# Forward-Euler operator M - dt*(volume + flux terms).
arhs = a_mass - dt*(a_int + a_flux)

# Each solve computes one forward-Euler update of q1 into dq1:
# M dq1 = arhs(q1).
q_problem = LinearVariationalProblem(a_mass, action(arhs, q1), dq1)
q_solver = LinearVariationalSolver(q_problem,
                                   solver_parameters={
                                       'ksp_type': 'cg',
                                       'pc_type': 'sor',
                                   })


# Time-series output of PV, streamfunction and velocity.
qfile = File("q.pvd")
psifile = File("psi.pvd")
vfile = File("v.pvd")
v = Function(Vu)


def _write_snapshot():
    """Append q0, and the psi0 and v inverted from it, to the .pvd files.

    Overwrites q1, psi0 and v. The time loop reassigns q1 from q0 and
    re-inverts psi0 at the start of every step, so the evolution is
    unaffected.
    """
    # Invert psi from the current q0 so the snapshot fields are consistent.
    # Without this, psi0 would be zero (before the first step) or inverted
    # from the last stage's q1 (after a step).
    q1.assign(q0)
    psi_solver.solve()
    v.project(gradperp(psi0))
    qfile << q0
    psifile << psi0
    vfile << v


_write_snapshot()

t = 0.
T = 0.1
dumpfreq = 100   # write a snapshot every dumpfreq steps
tdump = 0

v0 = Function(Vu)  # NOTE(review): not used in the code shown

# Three-stage SSP Runge-Kutta (Shu-Osher form). Each q_solver.solve()
# performs one forward-Euler update dq1 = E(q1), using the psi0 just
# inverted from q1.
while t < (T - Dt/2):
    # Stage 1: q1 = E(q0)
    q1.assign(q0)
    psi_solver.solve()
    q_solver.solve()
    q1.assign(dq1)
    # Stage 2: q1 = 3/4 q0 + 1/4 E(q1)
    psi_solver.solve()
    q_solver.solve()
    q1.assign((3/4)*q0 + (1/4)*dq1)
    # Stage 3: q0 = 1/3 q0 + 2/3 E(q1)
    psi_solver.solve()
    q_solver.solve()
    q0.assign((1/3)*q0 + (2/3)*dq1)

    t += Dt
    print(t)  # parenthesised form works on both Python 2 and 3
    tdump += 1
    if tdump == dumpfreq:
        tdump -= dumpfreq
        _write_snapshot()

# Make sure the final state is written even when the number of steps is not
# a multiple of dumpfreq.
if tdump != 0:
    _write_snapshot()










