#from firedrake import *
reactions_kernel = op2.Kernel("""
/* Application context: user data handed to the SNES callbacks through
 * their void *ctx argument. */
typedef struct {
  PetscScalar *jacValues; /* scratch array the Jacobian callback writes its entries into */
  PetscScalar *psi;       /* values read by the residual callback (copied from the kernel's psiVar) */
  PetscScalar *keq;       /* coefficients multiplying the product term (copied from the kernel's kVar) */
  PetscInt    problem;    /* selects which case of the callbacks' switch statements is evaluated */
} AppCtx;

/* Jacobian */
static void FormJacobian(SNES snes, Vec x, Mat jac, Mat B, void *ctx) {
  AppCtx *user = (AppCtx*)ctx;
  const PetscScalar *xx;
  VecGetArrayRead(x,&xx);
  switch (user->problem) {
    case 1:
    {
      user->jacValues[0] = -1;
      user->jacValues[1] = 0;
      user->jacValues[2] = -1;

      user->jacValues[3] = 0;
      user->jacValues[4] = -1;
      user->jacValues[5] = -1;
      
      user->jacValues[6] = -user->keq[0]*xx[1];
      user->jacValues[7] = -user->keq[0]*xx[0];
      user->jacValues[8] = 1;
      break;
    }
    default:
    { 
      exit(1);
    }
  }
  VecRestoreArrayRead(x,&xx);
}

/* Residual */
static void FormResidual(SNES snes, Vec x, Vec f, void *ctx) {
  AppCtx *user = (AppCtx*)ctx;
  const PetscScalar *xx;
  PetscScalar       *ff;
  VecGetArrayRead(x,&xx);
  VecGetArray(f,&ff);
  switch (user->problem) {
    case 1:
    {
      ff[0] = user->psi[0] - xx[0] - xx[1];
      ff[1] = user->psi[1] - xx[1] - xx[2];
      ff[2] = xx[2] - user->keq[0]*xx[0]*xx[1];
      break;
    }
    default:
    { 
      exit(1);
    }
  }
  VecRestoreArrayRead(x,&xx);
  VecRestoreArray(f,&ff);
}

/* Kernel entry point: solves one small nonlinear system with a sequential
 * SNES and writes the solution to cVar.
 *
 *   psiVar - input, numPsi values copied into the application context
 *   cVar   - output, receives the numC entries of the solution vector
 *   kVar   - input, numKeq values copied into the application context
 *
 * numC, numPsi, numKeq, VI and problem are placeholders substituted by the
 * enclosing Python format string before the kernel is compiled. When VI is
 * non-zero the solve uses SNESVINEWTONRSLS with bounds [0, 99] on every
 * unknown. Every PETSc object created here is destroyed before returning.
 * Because the kernel returns void, a failing PETSc call aborts the run.
 */
void solve_kernel(const double *psiVar, double *cVar, const double *kVar) {
  SNES              snes;
  KSP               ksp;
  PC                pc;
  Mat               J;
  Vec               x,r;
  Vec               lb = NULL,ub = NULL; /* only created when VI is set */
  AppCtx            user;
  const int         numC = %(numC)d;
  const int         numPsi = %(numPsi)d;
  const int         numKeq = %(numKeq)d;
  const int         VI = %(VI)d;
  const int         problem= %(problem)d;
  const PetscScalar *xx;
  PetscInt          i;
  PetscErrorCode    ierr;

  /* Initialize */
  ierr = SNESCreate(PETSC_COMM_SELF, &snes);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = MatCreateSeqDense(PETSC_COMM_SELF, numC, numC, NULL, &J);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = MatSetFromOptions(J);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = VecCreateSeq(PETSC_COMM_SELF, numC, &x);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = VecDuplicate(x,&r);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = PetscCalloc3(numC*numC,&user.jacValues,numPsi,&user.psi,numKeq,&user.keq);CHKERRABORT(PETSC_COMM_SELF,ierr);
  user.problem = problem;

  /* Initialize values: initial guess of 1 for every unknown */
  ierr = VecSet(x,1);CHKERRABORT(PETSC_COMM_SELF,ierr);
  for (i = 0; i < numPsi; i++) {
    user.psi[i] = psiVar[i];
  }
  for (i = 0; i < numKeq; i++) {
    user.keq[i] = kVar[i];
  }

  /* Solver */
  ierr = SNESSetFunction(snes,r,FormResidual,(void*)&user);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = SNESSetJacobian(snes,J,J,FormJacobian,(void*)&user);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = SNESGetKSP(snes,&ksp);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = KSPGetPC(ksp,&pc);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = PCSetType(pc,PCNONE);CHKERRABORT(PETSC_COMM_SELF,ierr);
  if (VI) {
    ierr = VecDuplicate(x,&lb);CHKERRABORT(PETSC_COMM_SELF,ierr);
    ierr = VecDuplicate(x,&ub);CHKERRABORT(PETSC_COMM_SELF,ierr);
    ierr = VecSet(lb,0);CHKERRABORT(PETSC_COMM_SELF,ierr);
    ierr = VecSet(ub,99);CHKERRABORT(PETSC_COMM_SELF,ierr);
    ierr = SNESSetType(snes,SNESVINEWTONRSLS);CHKERRABORT(PETSC_COMM_SELF,ierr);
    ierr = SNESVISetVariableBounds(snes,lb,ub);CHKERRABORT(PETSC_COMM_SELF,ierr);
  }
  /* Options are applied last so they can override the defaults set above */
  ierr = SNESSetFromOptions(snes);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = SNESSolve(snes,NULL,x);CHKERRABORT(PETSC_COMM_SELF,ierr);

  /* Store solution */
  ierr = VecGetArrayRead(x,&xx);CHKERRABORT(PETSC_COMM_SELF,ierr);
  for (i = 0; i < numC; i++) {
    cVar[i] = xx[i];
  }
  ierr = VecRestoreArrayRead(x,&xx);CHKERRABORT(PETSC_COMM_SELF,ierr);

  /* Destroy */
  ierr = SNESDestroy(&snes);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = MatDestroy(&J);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = VecDestroy(&x);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = VecDestroy(&r);CHKERRABORT(PETSC_COMM_SELF,ierr);
  /* VecDestroy is a no-op on a NULL Vec, so both bounds can be released
   * unconditionally; previously ub was never destroyed. */
  ierr = VecDestroy(&lb);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = VecDestroy(&ub);CHKERRABORT(PETSC_COMM_SELF,ierr);
  ierr = PetscFree3(user.jacValues,user.psi,user.keq);CHKERRABORT(PETSC_COMM_SELF,ierr);
}

""" % {'numPsi':W.dof_dset.cdim,'numC':W_R.dof_dset.cdim,'numKeq':numKeq,'VI': \
    opt_VI,'problem':problem},"solve_kernel")
