#include <deal.II/base/quadrature_lib.h>
#include <deal.II/base/function.h>

#include <deal.II/lac/vector.h>
#include <deal.II/lac/full_matrix.h>
#include <deal.II/lac/solver_cg.h>
#include <deal.II/lac/constraint_matrix.h>
#include <deal.II/lac/dynamic_sparsity_pattern.h>

#include <deal.II/lac/petsc_parallel_sparse_matrix.h>
#include <deal.II/lac/petsc_parallel_vector.h>
#include <deal.II/lac/petsc_solver.h>
#include <deal.II/lac/petsc_precondition.h>

#include <deal.II/grid/grid_generator.h>
#include <deal.II/grid/grid_tools.h>
#include <deal.II/grid/tria_accessor.h>
#include <deal.II/grid/tria_iterator.h>
#include <deal.II/dofs/dof_handler.h>
#include <deal.II/dofs/dof_accessor.h>
#include <deal.II/dofs/dof_tools.h>
#include <deal.II/fe/fe_values.h>
#include <deal.II/fe/fe_q.h>
#include <deal.II/fe/fe_nothing.h>
#include <deal.II/numerics/vector_tools.h>
#include <deal.II/numerics/data_out.h>

#include <deal.II/base/utilities.h>
#include <deal.II/base/conditional_ostream.h>
#include <deal.II/base/index_set.h>
#include <deal.II/lac/sparsity_tools.h>
#include <deal.II/distributed/tria.h>

#include <fstream>
#include <iostream>

namespace CrackedDomain { namespace aux {

  /**
   * Scalar (one-component) reference solution in two space dimensions.
   *
   * As implemented by the member definitions in this file, the function is
   * zero for x > 0 (including a band of width 1e-14 around x = 0) and equals
   * +x^2 exp(-x) for y > 0 and -x^2 exp(-x) for y < 0 when x < 0. It is
   * therefore discontinuous across the half-line {y = 0, x < 0}; evaluating
   * it exactly on that line throws.
   *
   * The three functions inherited from dealii::Function<2> are marked
   * `override` so that the compiler verifies that their signatures really
   * match the base-class virtuals. forcing() is an addition of this class.
   */
  class AnalyticalSolution : public dealii::Function<2>
  {
    public:
      AnalyticalSolution()
        : dealii::Function<2>(1)
      {}

      /** Value of the reference solution at @p p. */
      double
      value(const dealii::Point<2> & p, const unsigned int component=0) const override;

      /** Gradient of the reference solution at @p p (only the x-entry is ever non-zero). */
      dealii::Tensor<1,2>
      gradient(const dealii::Point<2> & p, const unsigned int component=0) const override;

      /** Laplacian of the reference solution at @p p. */
      double
      laplacian(const dealii::Point<2> & p, const unsigned int component=0) const override;

      /** Right-hand side matching the reference solution: the negative Laplacian at @p p. */
      virtual double
      forcing(const dealii::Point<2> & p, const unsigned int component=0) const;
  };

  /**
   * Scalar boundary function x^2 exp(-x) (see UpperCrack::value in this file).
   * Problem::setup_system() imposes it on boundary id 4.
   *
   * value() is marked `override` so that the compiler checks it against the
   * virtual declared in dealii::Function<2>.
   */
  class UpperCrack : public dealii::Function<2>
  {
    public:
      UpperCrack()
        : dealii::Function<2>(1)
      {}

      /** Boundary value at @p p; depends on the x-coordinate only. */
      double
      value(const dealii::Point<2> & p, const unsigned int component=0) const override;
  };

  /**
   * Scalar boundary function -x^2 exp(-x) (see LowerCrack::value in this
   * file). Problem::setup_system() imposes it on boundary id 5.
   *
   * value() is marked `override` so that the compiler checks it against the
   * virtual declared in dealii::Function<2>.
   */
  class LowerCrack : public dealii::Function<2>
  {
    public:
      LowerCrack()
        : dealii::Function<2>(1)
      {}

      /** Boundary value at @p p; depends on the x-coordinate only. */
      double
      value(const dealii::Point<2> & p, const unsigned int component=0) const override;
  };

  /**
   * Constant scalar boundary function with value exp(1) (see
   * UpperEdge::value in this file). Problem::setup_system() imposes it on
   * boundary id 3.
   *
   * value() is marked `override` so that the compiler checks it against the
   * virtual declared in dealii::Function<2>.
   */
  class UpperEdge : public dealii::Function<2>
  {
    public:
      UpperEdge()
        : dealii::Function<2>(1)
      {}

      /** Boundary value; independent of the evaluation point. */
      double
      value(const dealii::Point<2> & p, const unsigned int component=0) const override;
  };

  /**
   * Constant scalar boundary function with value -exp(1) (see
   * LowerEdge::value in this file). Problem::setup_system() imposes it on
   * boundary id 6.
   *
   * value() is marked `override` so that the compiler checks it against the
   * virtual declared in dealii::Function<2>.
   */
  class LowerEdge : public dealii::Function<2>
  {
    public:
      LowerEdge()
        : dealii::Function<2>(1)
      {}

      /** Boundary value; independent of the evaluation point. */
      double
      value(const dealii::Point<2> & p, const unsigned int component=0) const override;
  };

  /**
   * Evaluate the reference solution at @p p. The component index is ignored.
   *
   * @return 0 for x > 0 or |x| < 1e-14; otherwise +x^2 exp(-x) when y > 0 and
   *         -x^2 exp(-x) when y < 0.
   * @throws via AssertThrow when x is in the non-zero region and y is neither
   *         positive nor negative (i.e. exactly on the line y = 0).
   */
  double
  AnalyticalSolution::value(const dealii::Point<2> & p,
                            const unsigned int) const
  {
    const double px = p[0];
    const double py = p[1];

    // Zero region: right half-plane plus a small tolerance band around x = 0.
    const bool in_zero_region = (px > 0) || (std::abs(px) < 1.e-14);
    if (in_zero_region)
      return 0;

    const bool above = (py > 0);
    const bool below = (py < 0);

    // Neither side: the function has no single value here.
    if (!above && !below)
      {
        AssertThrow(false, dealii::ExcMessage("Error"));
        return 0;
      }

    const double magnitude = px*px*std::exp(-px);
    return above ? magnitude : -magnitude;
  }


  /**
   * Gradient of the reference solution at @p p. The component index is
   * ignored.
   *
   * Only the x-entry can be non-zero: it is 2x exp(-x) - x^2 exp(-x) for
   * y > 0 and the negative of that for y < 0 (both for x < 0 outside the
   * 1e-14 tolerance band); the zero tensor is returned for x > 0 or
   * |x| < 1e-14.
   *
   * @throws via AssertThrow when x is in the non-zero region and y is neither
   *         positive nor negative.
   */
  dealii::Tensor<1,2>
  AnalyticalSolution::gradient(const dealii::Point<2> & p,
                               const unsigned int) const
  {
    dealii::Tensor<1,2> grad;
    grad = 0;

    const double px = p[0];
    const double py = p[1];

    const bool in_zero_region = (px > 0) || (std::abs(px) < 1.e-14);
    if (in_zero_region)
      return grad;

    const bool above = (py > 0);
    const bool below = (py < 0);

    if (!above && !below)
      {
        AssertThrow(false, dealii::ExcMessage("Error"));
        return grad;
      }

    const double decay = std::exp(-px);

    if (above)
      grad[0] = 2*px*decay - px*px*decay;
    else
      grad[0] = -2*px*decay + px*px*decay;

    return grad;
  }


  /**
   * Laplacian of the reference solution at @p p. The component index is
   * ignored.
   *
   * @return 0 for x > 0 or |x| < 1e-14; otherwise
   *         (2 - 4x + x^2) exp(-x) for y > 0 and its negative for y < 0,
   *         evaluated term by term.
   * @throws via AssertThrow when x is in the non-zero region and y is neither
   *         positive nor negative.
   */
  double
  AnalyticalSolution::laplacian(const dealii::Point<2> & p,
                                const unsigned int) const
  {
    const double px = p[0];
    const double py = p[1];

    const bool in_zero_region = (px > 0) || (std::abs(px) < 1.e-14);
    if (in_zero_region)
      return 0;

    const bool above = (py > 0);
    const bool below = (py < 0);

    if (!above && !below)
      {
        AssertThrow(false, dealii::ExcMessage("Error"));
        return 0;
      }

    const double decay = std::exp(-px);

    if (above)
      return 2*decay - 4*px*decay + px*px*decay;

    return -2*decay + 4*px*decay - px*px*decay;
  }


  /**
   * Source term associated with the reference solution: the negated
   * Laplacian at @p p. Anything laplacian() throws propagates to the caller.
   */
  double
  AnalyticalSolution::forcing(const dealii::Point<2> & p,
                              const unsigned int component) const
  {
    const double lap = laplacian(p, component);
    return -lap;
  }


  /**
   * Returns x^2 exp(-x), where x is the first coordinate of @p p. The second
   * coordinate and the component index are not used.
   */
  double
  UpperCrack::value(const dealii::Point<2> & p,
                    const unsigned int) const
  {
    const double px    = p[0];
    const double decay = std::exp(-px);

    return px*px*decay;
  }


  /**
   * Returns -x^2 exp(-x), where x is the first coordinate of @p p. The second
   * coordinate and the component index are not used.
   */
  double
  LowerCrack::value(const dealii::Point<2> & p,
                    const unsigned int) const
  {
    const double px    = p[0];
    const double decay = std::exp(-px);

    return -(px*px*decay);
  }


  /**
   * Returns the constant exp(1); neither the point nor the component index
   * is used.
   */
  double
  UpperEdge::value(const dealii::Point<2> &,
                   const unsigned int) const
  {
    const double e = std::exp(1.0);
    return e;
  }


  /**
   * Returns the constant -exp(1); neither the point nor the component index
   * is used.
   */
  double
  LowerEdge::value(const dealii::Point<2> &,
                   const unsigned int) const
  {
    const double e = std::exp(1.0);
    return -e;
  }

}}

namespace CrackedDomain
{
  using namespace dealii;

  // Driver for the finite-element computation in this file: it owns the
  // distributed mesh, the DoF handler, the constraints and the PETSc linear
  // system, and exposes a single entry point, run().
  class Problem
  {
  public:
    Problem ();
    ~Problem ();

    // Builds and refines the mesh, tags the boundary faces, then calls
    // setup_system(), assemble_system(), solve() and (for at most 32 MPI
    // processes) output_results().
    void run ();

  private:
    // Distributes DoFs, sizes the vectors, builds the constraints (hanging
    // nodes + Dirichlet values) and the sparsity pattern / matrix.
    void setup_system ();
    // Assembles the stiffness matrix and right-hand side on locally owned cells.
    void assemble_system ();
    // Solves the linear system with CG + BoomerAMG and applies the constraints.
    void solve ();
    // Writes the solution to a VTU file (one file per MPI rank when parallel).
    void output_results () const;

    // NOTE: the constructor's initializer list passes mpi_communicator to
    // triangulation and triangulation to dof_handler, so the relative order
    // of these declarations must be kept.
    MPI_Comm                                  mpi_communicator;

    parallel::distributed::Triangulation<2>   triangulation;

    DoFHandler<2>                             dof_handler;
    FE_Q<2>                                   fe;

    // Index sets filled in setup_system().
    IndexSet                                  locally_owned_dofs;
    IndexSet                                  locally_relevant_dofs;

    // Hanging-node and Dirichlet constraints, rebuilt in setup_system().
    ConstraintMatrix                          constraints;

    PETScWrappers::MPI::SparseMatrix          system_matrix;
    // Initialized with both owned and relevant index sets (see setup_system()).
    PETScWrappers::MPI::Vector                locally_relevant_solution;
    PETScWrappers::MPI::Vector                system_rhs;

    // Output stream constructed with the condition "this process has rank 0".
    ConditionalOStream                        pcout;
  };



  /**
   * Set up the object on MPI_COMM_WORLD: the triangulation is attached to the
   * communicator, the DoF handler to the triangulation, the element is
   * FE_Q<2>(1), and pcout wraps std::cout with the condition that the
   * calling process has rank 0.
   */
  Problem::Problem ()
    :
    mpi_communicator (MPI_COMM_WORLD),
    triangulation (mpi_communicator),
    dof_handler (triangulation),
    fe (1),
    pcout (std::cout,
           Utilities::MPI::this_mpi_process(mpi_communicator) == 0)
  {}



  // Destructor: explicitly clears the DoF handler before the members are
  // destroyed.
  // NOTE(review): presumably needed so that dof_handler drops its reference
  // to `fe`, which is declared after it and therefore destroyed first --
  // confirm against the DoFHandler documentation.
  Problem::~Problem ()
  {
    dof_handler.clear ();
  }



  void Problem::setup_system ()
  {
    dof_handler.distribute_dofs (fe);

    locally_owned_dofs = dof_handler.locally_owned_dofs ();
    DoFTools::extract_locally_relevant_dofs (dof_handler,
                                             locally_relevant_dofs);

    locally_relevant_solution.reinit (locally_owned_dofs,
                                      locally_relevant_dofs, mpi_communicator);
    system_rhs.reinit (locally_owned_dofs, mpi_communicator);

    aux::AnalyticalSolution exact_sol;
    aux::UpperCrack upper_crack;
    aux::LowerCrack lower_crack;
    aux::UpperEdge upper_edge;
    aux::LowerEdge lower_edge;

    constraints.clear ();
    constraints.reinit (locally_relevant_dofs);
    DoFTools::make_hanging_node_constraints (dof_handler, constraints);

    for(unsigned int id(0); id < 3; ++id)
      dealii::VectorTools::interpolate_boundary_values(
        dof_handler, id, exact_sol, constraints);

    dealii::VectorTools::interpolate_boundary_values(
      dof_handler, 3, upper_edge, constraints);

    dealii::VectorTools::interpolate_boundary_values(
      dof_handler, 4, upper_crack, constraints);

    dealii::VectorTools::interpolate_boundary_values(
      dof_handler, 5, lower_crack, constraints);

    dealii::VectorTools::interpolate_boundary_values(
      dof_handler, 6, lower_edge, constraints);

    constraints.close ();

    DynamicSparsityPattern dsp (locally_relevant_dofs);

    DoFTools::make_sparsity_pattern (dof_handler, dsp,
                                     constraints, false);

    SparsityTools::distribute_sparsity_pattern (dsp,
                                                dof_handler.n_locally_owned_dofs_per_processor(),
                                                mpi_communicator,
                                                locally_relevant_dofs);

    system_matrix.reinit (locally_owned_dofs,
                          locally_owned_dofs,
                          dsp,
                          mpi_communicator);
  }



  void Problem::assemble_system ()
  {
    aux::AnalyticalSolution exact_sol;

    const QGauss<2>  quadrature_formula(2);

    FEValues<2> fe_values (fe, quadrature_formula, update_values
                                                 | update_gradients
                                                 | update_quadrature_points
                                                 | update_JxW_values);

    const unsigned int   dofs_per_cell = fe.dofs_per_cell;
    const unsigned int   n_q_points    = quadrature_formula.size();

    FullMatrix<double>   cell_matrix (dofs_per_cell, dofs_per_cell);
    Vector<double>       cell_rhs (dofs_per_cell);

    std::vector<types::global_dof_index> local_dof_indices (dofs_per_cell);

    typename DoFHandler<2>::active_cell_iterator
      cell = dof_handler.begin_active(),
      endc = dof_handler.end();

    for (; cell!=endc; ++cell)
      if (cell->is_locally_owned()) {
        cell_matrix = 0;
        cell_rhs = 0;

        fe_values.reinit (cell);

        for (unsigned int q_point=0; q_point<n_q_points; ++q_point) {
          const Point<2> & p = fe_values.quadrature_point(q_point);

          const double rhs_value = exact_sol.forcing(p);

          for (unsigned int i=0; i<dofs_per_cell; ++i) {
            for (unsigned int j=0; j<dofs_per_cell; ++j)
              cell_matrix(i,j) += (fe_values.shape_grad(i,q_point) *
                                   fe_values.shape_grad(j,q_point) *
                                   fe_values.JxW(q_point));

              cell_rhs(i) += (rhs_value *
                              fe_values.shape_value(i,q_point) *
                              fe_values.JxW(q_point));
          }
        }

        cell->get_dof_indices (local_dof_indices);
        constraints.distribute_local_to_global (cell_matrix,
                                                cell_rhs,
                                                local_dof_indices,
                                                system_matrix,
                                                system_rhs);
      }

    system_matrix.compress (VectorOperation::add);
    system_rhs.compress (VectorOperation::add);
  }



  void Problem::solve ()
  {
    PETScWrappers::MPI::Vector
      completely_distributed_solution (locally_owned_dofs, mpi_communicator);

    SolverControl solver_control (dof_handler.n_dofs(), 1e-12);

    PETScWrappers::SolverCG solver(solver_control, mpi_communicator);

    PETScWrappers::PreconditionBoomerAMG preconditioner;

    PETScWrappers::PreconditionBoomerAMG::AdditionalData data;

    data.symmetric_operator = true;

    preconditioner.initialize(system_matrix, data);

    solver.solve (system_matrix, completely_distributed_solution, system_rhs,
                  preconditioner);

    pcout << "   Solved in " << solver_control.last_step()
          << " iterations." << std::endl;

    constraints.distribute (completely_distributed_solution);

    locally_relevant_solution = completely_distributed_solution;
  }



  /**
   * Write the field "solution" in VTU format.
   *
   * With more than one MPI process every rank writes its own file
   * "solution-rankNN.vtu" (rank padded to two digits); a single-process run
   * writes "solution.vtu".
   */
  void Problem::output_results () const
  {
    const unsigned int n_ranks =
      dealii::Utilities::MPI::n_mpi_processes(mpi_communicator);
    const unsigned int my_rank =
      dealii::Utilities::MPI::this_mpi_process(mpi_communicator);

    const std::string filename =
      (n_ranks > 1)
      ? ("solution-rank"
         + dealii::Utilities::int_to_string(my_rank, 2)
         + ".vtu")
      : std::string("solution.vtu");

    dealii::DataOut<2> data_out;
    data_out.attach_dof_handler(dof_handler);
    data_out.add_data_vector(locally_relevant_solution, "solution");
    data_out.build_patches();

    std::ofstream output(filename);
    data_out.write_vtu(output);
  }



  /**
   * Run the whole computation: build the slit mesh, rotate it by -pi/2,
   * refine it globally six times, tag the boundary faces, then set up,
   * assemble and solve the system. Results are written only when at most
   * 32 MPI processes are in use.
   *
   * Boundary ids are assigned from the face-centre position and the outward
   * normal (scaled so that its largest entry is exactly +-1):
   *   0: normal -y, y = -1        1: normal +x, x = +1
   *   2: normal +y, y = +1        3: normal -x, x = -1, y > 0
   *   4: normal -y, y =  0        5: normal +y, y =  0
   *   6: normal -x, x = -1, y < 0
   *
   * The ids are set on every non-artificial cell, i.e. on ghost cells as
   * well as locally owned ones. VectorTools::interpolate_boundary_values()
   * also visits ghost cells; if their boundary faces kept the default id 0,
   * they would receive the reference-solution data instead of the
   * crack/edge data in runs with more than one MPI process.
   */
  void Problem::run ()
  {
    GridGenerator::hyper_cube_slit(triangulation, -1, 1, false);
    GridTools::rotate(-M_PI/2, triangulation);

    triangulation.refine_global (6);

    // set boundary indicators
    {
      // FE_Nothing carries no DoFs; FEFaceValues is used here only to obtain
      // the outward normal at the face midpoint.
      FE_Nothing<2> fe_nothing(2);
      QMidpoint<1> quad;
      FEFaceValues<2> fe_face_values(fe_nothing, quad, update_normal_vectors);

      parallel::distributed::Triangulation<2>::active_cell_iterator
        cell(triangulation.begin_active()),
        endc(triangulation.end());

      for(; cell != endc; ++cell) {
        // Tag owned AND ghost cells; only artificial cells are skipped.
        if(cell->is_artificial())
          continue;

        for(unsigned int f(0); f<GeometryInfo<2>::faces_per_cell; ++f) {
          if(cell->face(f)->at_boundary()) {
            fe_face_values.reinit(cell, f);

            Point<2> face_center = cell->face(f)->center();

            // Scale by the max-norm so the dominant entry is exactly +-1 and
            // can be compared with == below.
            Tensor<1,2> n = fe_face_values.normal_vector(0);
            const double linfty = std::max(std::abs(n[0]), std::abs(n[1]));
            n /= linfty;

            const double x = face_center[0];
            const double y = face_center[1];

            if((n[1] == -1) and (std::abs(y+1) < 1.e-13))
              cell->face(f)->set_all_boundary_ids(0);

            if((n[0] == 1) and (std::abs(x-1) < 1.e-13))
              cell->face(f)->set_all_boundary_ids(1);

            if((n[1] == 1) and (std::abs(y-1) < 1.e-13))
              cell->face(f)->set_all_boundary_ids(2);

            if((n[0] == -1) and (std::abs(x+1) < 1.e-13) and (y > 0))
              cell->face(f)->set_all_boundary_ids(3);

            if((n[1] == -1) and (std::abs(y) < 1.e-13))
              cell->face(f)->set_all_boundary_ids(4);

            if((n[1] == 1) and (std::abs(y) < 1.e-13))
              cell->face(f)->set_all_boundary_ids(5);

            if((n[0] == -1) and (std::abs(x+1) < 1.e-13) and (y < 0))
              cell->face(f)->set_all_boundary_ids(6);
          }
        }
      }
    }

    setup_system ();

    pcout << "   Number of active cells:       "
          << triangulation.n_global_active_cells()
          << std::endl
          << "   Number of degrees of freedom: "
          << dof_handler.n_dofs()
          << std::endl;

    assemble_system ();
    solve ();

    // Each rank writes its own file, so output is skipped for large runs.
    if (Utilities::MPI::n_mpi_processes(mpi_communicator) <= 32) {
      output_results ();
    }

    pcout << std::endl;
}



} // namespace CrackedDomain



/**
 * Program entry point.
 *
 * main() must live in the global namespace to be the program's entry point,
 * so namespace CrackedDomain is closed directly above rather than after this
 * function.
 *
 * Initializes MPI (third argument of MPI_InitFinalize is 1), constructs a
 * CrackedDomain::Problem and runs it. Any exception is reported on std::cerr.
 *
 * @return 0 on success, 1 if an exception was caught.
 */
int main(int argc, char *argv[])
{
  try
    {
      using namespace dealii;
      using namespace CrackedDomain;

      Utilities::MPI::MPI_InitFinalize mpi_initialization(argc, argv, 1);

      Problem problem;
      problem.run ();
    }
  catch (const std::exception &exc)
    {
      std::cerr << std::endl << std::endl
                << "----------------------------------------------------"
                << std::endl;
      std::cerr << "Exception on processing: " << std::endl
                << exc.what() << std::endl
                << "Aborting!" << std::endl
                << "----------------------------------------------------"
                << std::endl;

      return 1;
    }
  catch (...)
    {
      std::cerr << std::endl << std::endl
                << "----------------------------------------------------"
                << std::endl;
      std::cerr << "Unknown exception!" << std::endl
                << "Aborting!" << std::endl
                << "----------------------------------------------------"
                << std::endl;
      return 1;
    }

  return 0;
}
