import petsc4py
petsc4py.init('-log_view')
import time
from firedrake import *
from petsc4py import PETSc
from mpi4py import MPI
import argparse
# Command-line interface of the scaling study.
#
# Options:
#   --n             base mesh resolution per dimension (int, default 32)
#   --scaling_type  'weak' or 'strong' (default 'strong'); any other value is
#                   rejected by argparse instead of silently being treated as
#                   a strong-scaling run
#   --poly_degree   polynomial degree of the Lagrange space (int, default 2)
#   --save_timings  flag; when given, timings are appended to the output file
parser = argparse.ArgumentParser(
            description='Script for performing scaling analysis of the \
            firedrake library. The poisson equation is solved over a \
            unit cube and run time of the following quantities is \
            measured - a) matrix assembly, b) rhs assembly, \
            and c) the linear solve')
parser.add_argument('--n', type=int, default=32,
                    help='The problem size per dimension per core (default = 32)')
parser.add_argument('--scaling_type', dest="scaling_type", default='strong',
                    choices=('weak', 'strong'),
                    help="Scaling type, 'weak' or 'strong' (default)")
parser.add_argument('--poly_degree', type=int, default=2,
                    help='The degree of the interpolating polynomial \
                    (default = 2)')
parser.add_argument('--save_timings', action='store_true',
                    help='Save timings or not (default : False)')

args = parser.parse_args()

# Spatial dimension of the problem; also the root taken when scaling the mesh
# resolution with the number of ranks for weak scaling.
dim = 3

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

n            = args.n
scaling_type = args.scaling_type
poly_degree  = args.poly_degree
save_timings = args.save_timings

# The output file name uses the *requested* n, i.e. the value before the
# weak-scaling adjustment below, so all runs of one study share a file.
filename = 'poisson_dim_%d_n_%d_deg_%d_%s.time' % \
               (dim, n, poly_degree, scaling_type)

if scaling_type == 'weak':
    # Grow the per-dimension resolution with size**(1/dim) so that the amount
    # of work per rank stays roughly constant.
    #
    # The floating-point root of a perfect cube can land just below the exact
    # integer (64 ** (1.0 / 3) == 3.9999999999999996), which plain truncation
    # would turn into an off-by-one mesh size (127 instead of 128 for n = 32
    # on 64 ranks). A small tolerance is added before truncating; it is far
    # larger than the rounding error and far smaller than the distance of any
    # genuine non-integer result from the next integer at realistic sizes, so
    # non-cube rank counts still round down as before.
    n = int(n * size ** (1.0 / dim) + 1e-9)

###############################################################################
# Tune the PyOP2 and COFFEE parameters used for code generation
###############################################################################
# Each option is set through item assignment on the parameters object, one key
# at a time and in a fixed order.
pyop2_settings = (
    ('compiler', 'intel'),
    ('blas', 'mkl'),
    ('simd_isa', 'avx'),
    ('cflags', '-O3'),
)
for option, value in pyop2_settings:
    parameters['pyop2_options'][option] = value

# NOTE(review): "licm" and "ap" are presumably COFFEE's loop-invariant code
# motion and align-and-pad optimisations -- confirm against the COFFEE docs.
for coffee_flag in ("licm", "ap"):
    parameters["coffee"][coffee_flag] = True

###############################################################################
# Define the Poisson problem on the unit cube
###############################################################################
# n cells along each edge of the cube, Lagrange elements of the chosen degree.
mesh = UnitCubeMesh(n, n, n)
V = FunctionSpace(mesh, "Lagrange", poly_degree)

# Written to the timings file later on.
# NOTE(review): presumably the number of dofs owned by this rank -- confirm
# against PyOP2's DataSet.size.
dofs_per_core = V.dof_dset.size

# Homogeneous Dirichlet condition on the boundaries marked 3 and 4.
# NOTE(review): for UnitCubeMesh these should be the y == 0 and y == 1 faces
# -- confirm against the Firedrake mesh documentation.
bc = DirichletBC(V, 0.0, [3, 4])

trial = TrialFunction(V)
v = TestFunction(V)

# Source term, interpolated into V.
source = "48*pi*pi*cos (4*pi*x[0])*sin (4*pi*x[1])*cos (4*pi*x[2])"
f = Function(V).interpolate(Expression(source))

# Bilinear and linear forms of the weak formulation.
a = inner(grad(trial), grad(v)) * dx
L = f * v * dx

# Function that will receive the discrete solution.
u = Function(V)

###############################################################################
# Begin matrix assembly
#
# Wall-clock time (MPI.Wtime) spent in the call that assembles the bilinear
# form `a` with the boundary condition `bc`. The interval is measured on each
# rank separately.
###############################################################################
mat_assembly_start = MPI.Wtime()

# NOTE(review): some Firedrake versions assemble matrices lazily; if that is
# the case here, the actual assembly work may happen later (e.g. inside the
# solve) and not inside this interval -- confirm for the version in use.
A = assemble(a, bcs=bc)

mat_assembly_end = MPI.Wtime()

# Elapsed seconds on this rank.
mat_assembly_time = mat_assembly_end - mat_assembly_start
###############################################################################
# End matrix assembly
###############################################################################


###############################################################################
# Begin RHS assembly
#
# Wall-clock time (MPI.Wtime) spent assembling the linear form `L` and
# applying the boundary condition to the resulting vector. The interval is
# measured on each rank separately.
###############################################################################
rhs_assembly_start = MPI.Wtime()

b = assemble(L)
# Applying the boundary condition to the assembled vector is part of the
# timed region.
bc.apply(b)

rhs_assembly_end = MPI.Wtime()

# Elapsed seconds on this rank.
rhs_assembly_time = rhs_assembly_end - rhs_assembly_start
###############################################################################
# End RHS assembly
###############################################################################


###############################################################################
# Timed linear solve: start
###############################################################################
solve_start = MPI.Wtime()

# Conjugate gradients preconditioned with hypre's BoomerAMG; the BoomerAMG
# strong threshold and agg_nl setting as well as the Krylov tolerances are
# fixed explicitly so that every run uses the same solver configuration.
amg_cg_options = {
    'ksp_type': 'cg',
    'pc_type': 'hypre',
    'pc_hypre_type': 'boomeramg',
    'pc_hypre_boomeramg_strong_threshold': 0.75,
    'pc_hypre_boomeramg_agg_nl': 2,
    'ksp_rtol': 1e-6,
    'ksp_atol': 1e-15,
}
# The solution is written into u.
solve(A, u, b, solver_parameters=amg_cg_options)

solve_end = MPI.Wtime()

# Elapsed seconds on this rank.
solve_time = solve_end - solve_start
###############################################################################
# Timed linear solve: end
###############################################################################
# Wait for every rank to finish its solve before combining the timings.
comm.Barrier()

# Reduce timings: take the maximum over all ranks, i.e. the time of the
# slowest rank for each phase, and deliver it to rank 0 (root = 0).
# The reductions assign the max_* names on every rank, so no placeholder
# initialisation is needed beforehand.
max_mat_assembly_time = comm.reduce(mat_assembly_time, op=MPI.MAX, root=0)
max_rhs_assembly_time = comm.reduce(rhs_assembly_time, op=MPI.MAX, root=0)
max_solve_time = comm.reduce(solve_time, op=MPI.MAX, root=0)


# Analytical solution, interpolated into V and compared with the computed
# solution u in the L2 norm. The interpolation and the assembly of the norm
# are executed by every rank.
analytical = "cos(4*pi*x[0])*sin(4*pi*x[1])*cos(4*pi*x[2])"
# Kept under its own name so that the bilinear form `a` is not overwritten.
u_exact    = Function(V).interpolate(Expression(analytical))
l2         = sqrt(assemble(dot(u - u_exact, u - u_exact) * dx))

# Report the discretisation error so that the computed norm is not silently
# discarded; a large value indicates that the timed solve did not converge to
# the expected solution.
if rank == 0:
    print("L2 error %g" % float(l2))

# Print and save timings (rank 0 only, and only when --save_timings is given).
if rank == 0 and save_timings:
    print("NUMPROCS %d" % size)
    # One line per run: number of ranks, rank 0's value of dofs_per_core, and
    # the maximum time over all ranks for each of the three phases.
    out_list = [size, dofs_per_core, max_mat_assembly_time,
                max_rhs_assembly_time, max_solve_time]
    out_str = ' '.join(str(i) for i in out_list)
    # Append mode: results of successive runs accumulate in the same file.
    # The handle is named so that it does not shadow the source function f;
    # leaving the with-block flushes and closes the file.
    with open(filename, 'a') as timings_file:
        timings_file.write(out_str + '\n')


