#=================================================
#
#  3 component reactions kernel of the form:
#
#    aA + bB <=> cC
#
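#  where a, b, c are the stoichiometric coefficients
#  and A, B, C are the primary/secondary species.
#
#  Problem type 1 assumes a = b = c = 1 and solves, per entity,
#    psiA = cA + cC,  psiB = cB + cC,  cC = k1*cA*cB
#  for (cA, cB, cC) with a PETSc SNES Newton solve.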
#
#=================================================
reactions3_kernel = op2.Kernel("""
/* File-scope state shared between solve_kernel and the SNES callbacks.
 * solve_kernel fills psiValues/keq/problem_type before each solve. */

/* 3x3 Jacobian values handed to MatSetValues (row-oriented), identity
 * until FormJacobian overwrites them. */
PetscScalar jacValues[9] = {
  1, 0, 0,
  0, 1, 0,
  0, 0, 1
};
PetscScalar psiValues[2] = {0, 0};  /* psiA, psiB for the current solve */
PetscScalar keq[1]       = {0};     /* k1 for the current solve */
PetscInt    idxm[3]      = {0, 1, 2};  /* row/column indices of the 3x3 block */
PetscInt    problem_type;           /* selects the equation set in callbacks */

/* Jacobian callback registered with SNESSetJacobian.
 *
 * For problem type 1, with x = (cA, cB, cC), fills the derivative of the
 * residual assembled in FormResidual:
 *   row 0: d(psiA - cA - cC)/dx    = [ -1,        0,        -1 ]
 *   row 1: d(psiB - cB - cC)/dx    = [  0,       -1,        -1 ]
 *   row 2: d(cC - keq*cA*cB)/dx    = [ -keq*cB,  -keq*cA,    1 ]
 * The values live in the global jacValues and are inserted into B at the
 * rows/columns listed in idxm. For any other problem type the previously
 * stored jacValues are reused unchanged.
 *
 * Errors from PETSc calls are propagated back to SNES.
 */
PetscErrorCode FormJacobian(SNES snes, Vec x, Mat jac, Mat B, void *ctx) {
  const PetscScalar *xx;
  PetscErrorCode    ierr;
  PetscFunctionBegin;
  ierr = VecGetArrayRead(x,&xx);CHKERRQ(ierr);
  switch (problem_type) {
    case 1:
    {
      jacValues[0] = -1;
      jacValues[1] = 0;
      jacValues[2] = -1;

      jacValues[3] = 0;
      jacValues[4] = -1;
      jacValues[5] = -1;

      jacValues[6] = -keq[0]*xx[1];
      jacValues[7] = -keq[0]*xx[0];
      jacValues[8] = 1;
      break;
    }
    default:
      /* No analytic Jacobian for this problem type: keep jacValues as is. */
      break;
  }
  ierr = VecRestoreArrayRead(x,&xx);CHKERRQ(ierr);
  ierr = MatSetValues(B,3,idxm,3,idxm,jacValues,INSERT_VALUES);CHKERRQ(ierr);
  ierr = MatAssemblyBegin(B,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  ierr = MatAssemblyEnd(B,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  /* SNES may pass a distinct operator matrix; it must be assembled too. */
  if (jac != B) {
    ierr = MatAssemblyBegin(jac,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
    ierr = MatAssemblyEnd(jac,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}

/* Residual callback registered with SNESSetFunction.
 *
 * For problem type 1, with x = (cA, cB, cC):
 *   f0 = psiA - cA - cC
 *   f1 = psiB - cB - cC
 *   f2 = cC - keq*cA*cB
 * where psiA/psiB come from psiValues and keq from keq[0]. For any other
 * problem type f is left untouched.
 *
 * Errors from PETSc calls are propagated back to SNES.
 */
PetscErrorCode FormResidual(SNES snes, Vec x, Vec f, void *ctx) {
  const PetscScalar *xx;
  PetscScalar       *ff;
  PetscErrorCode    ierr;
  PetscFunctionBegin;
  ierr = VecGetArrayRead(x,&xx);CHKERRQ(ierr);
  ierr = VecGetArray(f,&ff);CHKERRQ(ierr);
  switch (problem_type) {
    case 1:
    {
      ff[0] = psiValues[0] - xx[0] - xx[2];
      ff[1] = psiValues[1] - xx[1] - xx[2];
      ff[2] = xx[2] - keq[0]*xx[0]*xx[1];
      break;
    }
    default:
      /* Unsupported problem type: residual is left as is. */
      break;
  }
  ierr = VecRestoreArrayRead(x,&xx);CHKERRQ(ierr);
  ierr = VecRestoreArray(f,&ff);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

void solve_kernel(const double *psiA, const double *psiB, double *cA, double *cB, 
    double *cC, const double *k1) {
  SNES              snes;
  KSP               ksp;
  PC                pc;
  Mat               J;
  Vec               x,r,lb,ub;
  const int         numC = 3;
  const int         VI = %(VI)d;
  const PetscScalar *xx;
  PetscInt          i;
  PetscErrorCode    ierr;
  int argc = 0;
  char **argv = "";

  /* Initialize */
  PetscInitialize(&argc,&argv,NULL,NULL);
  problem_type = %(problem)d;
  SNESCreate(PETSC_COMM_SELF, &snes);
  MatCreateSeqDense(PETSC_COMM_SELF, numC, numC, NULL, &J);
  MatSetFromOptions(J);
  VecCreateSeq(PETSC_COMM_SELF, numC, &x);
  VecDuplicate(x,&r);

  /* Initialize values */
  VecSet(x,1);
  psiValues[0] = psiA[0];
  psiValues[1] = psiB[0];
  keq[0] = k1[0];

  /* Solver */
  SNESSetFunction(snes,r,FormResidual,NULL);
  SNESSetJacobian(snes,J,J,FormJacobian,NULL);
  SNESGetKSP(snes,&ksp);
  KSPGetPC(ksp,&pc);
  PCSetType(pc,PCNONE);
  if (VI) {
    VecDuplicate(x,&lb);
    VecDuplicate(x,&ub);
    VecSet(lb,0);
    VecSet(ub,2);
    SNESSetType(snes,SNESVINEWTONRSLS);
    SNESVISetVariableBounds(snes,lb,ub);
  }
  SNESSetFromOptions(snes);
  SNESSolve(snes,NULL,x);

  /* Store solution */
  VecGetArrayRead(x,&xx);
  cA[0] = xx[0];
  cB[0] = xx[1];
  cC[0] = xx[2];
  VecRestoreArrayRead(x,&xx);

  /* Destroy */
  SNESDestroy(&snes);
  MatDestroy(&J);
  VecDestroy(&x);
  VecDestroy(&r);
  if (VI) {VecDestroy(&lb);}
  PetscFinalize();
}

""" % {'VI': opt_VI,'problem':problem},"solve_kernel")
