#include <deal.II/base/quadrature_lib.h>
#include <deal.II/base/function.h>
#include <deal.II/base/logstream.h>
#include <deal.II/base/multithread_info.h>
#include <deal.II/lac/vector.h>
#include <deal.II/lac/full_matrix.h>
#include <deal.II/lac/constraint_matrix.h>
#include <deal.II/lac/dynamic_sparsity_pattern.h>
#include <deal.II/lac/sparsity_tools.h>
#include <deal.II/grid/tria.h>
#include <deal.II/grid/grid_generator.h>
#include <deal.II/grid/grid_refinement.h>
#include <deal.II/grid/tria_accessor.h>
#include <deal.II/grid/tria_iterator.h>
#include <deal.II/grid/tria_boundary_lib.h>
#include <deal.II/dofs/dof_handler.h>
#include <deal.II/dofs/dof_accessor.h>
#include <deal.II/dofs/dof_tools.h>
#include <deal.II/fe/fe_values.h>
#include <deal.II/fe/fe_system.h>
#include <deal.II/fe/fe_q.h>
#include <deal.II/numerics/vector_tools.h>
#include <deal.II/numerics/matrix_tools.h>
#include <deal.II/numerics/data_out.h>
#include <deal.II/numerics/error_estimator.h>

#include <deal.II/base/mpi.h>

#include <deal.II/lac/petsc_parallel_vector.h>
#include <deal.II/lac/petsc_parallel_sparse_matrix.h>

#include <deal.II/lac/petsc_solver.h>
#include <deal.II/lac/petsc_precondition.h>

#include <deal.II/grid/grid_tools.h>
#include <deal.II/dofs/dof_renumbering.h>

#include <deal.II/distributed/tria.h>
#include <deal.II/distributed/grid_refinement.h>

#include <fstream>
#include <iostream>
#include <sstream>

namespace DistributedElasticity
{
  using namespace dealii;

  // Vector-valued body force with one component per spatial direction.
  // Both evaluation entry points override the virtual interface of
  // Function<dim>; `override` makes any signature mismatch a compile error.
  template <int dim>
  class RightHandSide :  public Function<dim>
  {
  public:
    RightHandSide ();

    // Evaluate the force at a single point; `values` must have dim entries.
    virtual void vector_value (const Point<dim> &p,
                               Vector<double> &values) const override;

    // Evaluate the force at every point of `points`; `value_list` must have
    // one (dim-sized) entry per point.
    virtual void vector_value_list (const std::vector<Point<dim> > &points,
                                    std::vector<Vector<double> > &value_list) const override;
  };


  // The force field has `dim` components (one per displacement direction).
  template <int dim>
  RightHandSide<dim>::RightHandSide ()
    : Function<dim> (dim)
  {}


  // Evaluate the body force at point p.
  //
  // x-component: 1 inside the two disks/balls of radius 0.2 centred at
  //              (+0.5, 0, ...) and (-0.5, 0, ...), 0 elsewhere.
  // y-component: 1 inside the disk/ball of radius 0.2 around the origin.
  // Every other component (only present for dim == 3) is 0.
  //
  // `values` must already have size dim.
  template <int dim>
  inline
  void RightHandSide<dim>::vector_value (const Point<dim> &p,
                                         Vector<double> &values) const
  {
    Assert (values.size() == dim,
            ExcDimensionMismatch (values.size(), dim));
    Assert (dim >= 2, ExcInternalError());

    // Only components 0 and 1 are assigned below. Reset the whole vector so
    // that, in 3d, component 2 does not keep whatever the caller passed in.
    values = 0;

    Point<dim> point_1, point_2;
    point_1(0) = 0.5;
    point_2(0) = -0.5;

    const double radius_squared = 0.2 * 0.2;

    if (((p - point_1).norm_square() < radius_squared) ||
        ((p - point_2).norm_square() < radius_squared))
      values(0) = 1;

    if (p.square() < radius_squared)
      values(1) = 1;
  }



  // Evaluate the body force at each point of `points`, writing the result
  // for points[k] into value_list[k]. Both lists must have the same length.
  template <int dim>
  void RightHandSide<dim>::vector_value_list (const std::vector<Point<dim> > &points,
                                              std::vector<Vector<double> > &value_list) const
  {
    const unsigned int n_points = points.size();
    Assert (value_list.size() == n_points,
            ExcDimensionMismatch (value_list.size(), n_points));

    // Qualified call: dispatches statically to this class's implementation.
    for (unsigned int k = 0; k != n_points; ++k)
      RightHandSide<dim>::vector_value (points[k], value_list[k]);
  }

  // Linear elasticity on a parallel::distributed::Triangulation, assembled
  // and solved with PETSc MPI matrices/vectors and adaptively refined with
  // the Kelly error indicator.
  template<int dim>
    class ElasticProblem
    {
    public:
      ElasticProblem ();
      ~ElasticProblem ();
      // Drives the adaptive loop: initial mesh or refinement, then setup,
      // assembly and solve, for a fixed number of cycles.
      void
      run ();

    private:
      // Distributes dofs, builds hanging-node constraints, the distributed
      // sparsity pattern, and sizes the matrix and vectors.
      void
      setup_system ();
      // Assembles the elasticity system on this rank's cells and applies
      // homogeneous Dirichlet values on boundary id 0.
      void
      assemble_system ();
      // Solves the linear system; returns the number of solver iterations.
      unsigned int
      solve ();
      // Marks and executes refinement/coarsening based on error indicators.
      void
      refine_grid ();

      // NOTE: members are initialized in declaration order. The constructor
      // relies on mpi_communicator preceding the rank/size fields and on
      // triangulation preceding dof_handler; do not reorder.
      MPI_Comm mpi_communicator;

      const unsigned int n_mpi_processes;
      const unsigned int this_mpi_process;

      parallel::distributed::Triangulation<dim> triangulation;
      DoFHandler<dim> dof_handler;

      // dim-component system of Q1 elements (one per displacement component).
      FESystem<dim> fe;

      IndexSet locally_owned_dofs;
      IndexSet locally_relevant_dofs;

      ConstraintMatrix constraints;

      PETScWrappers::MPI::SparseMatrix system_matrix;
      // NOTE(review): despite its name, setup_system() initializes this vector
      // with the locally owned layout only (no ghost entries).
      PETScWrappers::MPI::Vector locally_relevant_solution;
      PETScWrappers::MPI::Vector system_rhs;
    };

  // Set up the communicator, the rank information, a distributed mesh with
  // refinement/coarsening smoothing, and a dim-component Q1 element.
  // The initializers follow the member declaration order.
  template <int dim>
  ElasticProblem<dim>::ElasticProblem ()
    : mpi_communicator (MPI_COMM_WORLD)
    , n_mpi_processes (Utilities::MPI::n_mpi_processes (mpi_communicator))
    , this_mpi_process (Utilities::MPI::this_mpi_process (mpi_communicator))
    , triangulation (mpi_communicator,
                     typename Triangulation<dim>::MeshSmoothing (
                       Triangulation<dim>::smoothing_on_refinement |
                       Triangulation<dim>::smoothing_on_coarsening))
    , dof_handler (triangulation)
    , fe (FE_Q<dim> (1), dim)
  {}

  // Drop the DoF data explicitly. The destructor body runs before any member
  // is destroyed, so this happens while the triangulation still exists.
  template <int dim>
  ElasticProblem<dim>::~ElasticProblem ()
  {
    this->dof_handler.clear ();
  }


  // Distribute degrees of freedom on the current mesh and size all linear
  // algebra objects: hanging-node constraints, the distributed sparsity
  // pattern, the system matrix, and the solution and right-hand-side vectors.
  template <int dim>
  void ElasticProblem<dim>::setup_system ()
  {
    dof_handler.distribute_dofs (fe);

    // Owned dofs are the rows this rank stores. Relevant dofs are owned dofs
    // plus those on ghost cells.
    locally_owned_dofs = dof_handler.locally_owned_dofs ();
    DoFTools::extract_locally_relevant_dofs (dof_handler,
                                             locally_relevant_dofs);

    // Both vectors use the locally owned layout without ghost entries.
    // assemble_system() writes boundary values into the solution vector
    // through MatrixTools::apply_boundary_values, and deal.II's ghosted
    // vectors are read-only. Code that needs ghost values must copy the
    // solution into a separately ghosted vector.
    // (Each reinit is collective and is done exactly once per call.)
    locally_relevant_solution.reinit (locally_owned_dofs, mpi_communicator);
    system_rhs.reinit (locally_owned_dofs, mpi_communicator);

    constraints.clear ();
    constraints.reinit (locally_relevant_dofs);
    DoFTools::make_hanging_node_constraints (dof_handler, constraints);
    constraints.close ();

    // Build the pattern for locally relevant rows, omitting entries for
    // constrained dofs (keep_constrained_dofs = false). Then exchange it so
    // that each rank holds the complete pattern of the rows it owns.
    DynamicSparsityPattern dsp (locally_relevant_dofs);
    DoFTools::make_sparsity_pattern (dof_handler, dsp, constraints, false);
    SparsityTools::distribute_sparsity_pattern (
      dsp, dof_handler.n_locally_owned_dofs_per_processor (),
      mpi_communicator, locally_relevant_dofs);

    system_matrix.reinit (locally_owned_dofs, locally_owned_dofs, dsp,
                          mpi_communicator);
  }

  // Assemble the isotropic linear-elasticity system (Lame coefficients
  // lambda = mu = 1) on the cells owned by this rank. Contributions are
  // condensed through the hanging-node constraints. Afterwards, homogeneous
  // Dirichlet values are imposed on boundary id 0.
  template <int dim>
  void ElasticProblem<dim>::assemble_system ()
  {
    const QGauss<dim> quadrature_formula (2);
    FEValues<dim> fe_values (fe, quadrature_formula,
                             update_values | update_gradients |
                             update_quadrature_points | update_JxW_values);

    const unsigned int dofs_per_cell = fe.dofs_per_cell;
    const unsigned int n_q_points    = quadrature_formula.size ();

    FullMatrix<double> cell_matrix (dofs_per_cell, dofs_per_cell);
    Vector<double>     cell_rhs (dofs_per_cell);

    std::vector<types::global_dof_index> local_dof_indices (dofs_per_cell);

    std::vector<double> lambda_values (n_q_points);
    std::vector<double> mu_values (n_q_points);

    Functions::ConstantFunction<dim> lambda (1.), mu (1.);

    RightHandSide<dim>           right_hand_side;
    std::vector<Vector<double> > rhs_values (n_q_points, Vector<double> (dim));

    // Each shape function of the FESystem belongs to exactly one vector
    // component. The mapping depends only on `fe`, so look it up once
    // instead of once per cell.
    std::vector<unsigned int> component_of (dofs_per_cell);
    for (unsigned int i = 0; i < dofs_per_cell; ++i)
      component_of[i] = fe.system_to_component_index (i).first;

    for (const auto &cell : dof_handler.active_cell_iterators ())
      {
        // Skip cells this rank does not own; their owners assemble them.
        if (!cell->is_locally_owned ())
          continue;

        cell_matrix = 0;
        cell_rhs    = 0;

        fe_values.reinit (cell);

        lambda.value_list (fe_values.get_quadrature_points (), lambda_values);
        mu.value_list (fe_values.get_quadrature_points (), mu_values);

        for (unsigned int i = 0; i < dofs_per_cell; ++i)
          {
            const unsigned int component_i = component_of[i];

            for (unsigned int j = 0; j < dofs_per_cell; ++j)
              {
                const unsigned int component_j = component_of[j];

                for (unsigned int q_point = 0; q_point < n_q_points; ++q_point)
                  {
                    // lambda (div u)(div v) + mu (grad u)^T : grad v
                    // + mu (grad u : grad v), the last only for equal components.
                    cell_matrix (i, j) +=
                      ((fe_values.shape_grad (i, q_point)[component_i] *
                        fe_values.shape_grad (j, q_point)[component_j] *
                        lambda_values[q_point])
                       +
                       (fe_values.shape_grad (i, q_point)[component_j] *
                        fe_values.shape_grad (j, q_point)[component_i] *
                        mu_values[q_point])
                       +
                       ((component_i == component_j) ?
                        (fe_values.shape_grad (i, q_point) *
                         fe_values.shape_grad (j, q_point) *
                         mu_values[q_point]) :
                        0)) *
                      fe_values.JxW (q_point);
                  }
              }
          }

        right_hand_side.vector_value_list (fe_values.get_quadrature_points (),
                                           rhs_values);
        for (unsigned int i = 0; i < dofs_per_cell; ++i)
          {
            const unsigned int component_i = component_of[i];

            for (unsigned int q_point = 0; q_point < n_q_points; ++q_point)
              cell_rhs (i) += fe_values.shape_value (i, q_point) *
                              rhs_values[q_point] (component_i) *
                              fe_values.JxW (q_point);
          }

        cell->get_dof_indices (local_dof_indices);
        constraints.distribute_local_to_global (cell_matrix, cell_rhs,
                                                local_dof_indices,
                                                system_matrix, system_rhs);
      }

    // Exchange contributions that were written into rows owned by other ranks.
    system_matrix.compress (VectorOperation::add);
    system_rhs.compress (VectorOperation::add);

    // Zero displacement on boundary id 0. Columns are not eliminated (false),
    // which keeps the call valid for PETSc MPI matrices.
    std::map<types::global_dof_index, double> boundary_values;
    VectorTools::interpolate_boundary_values (dof_handler,
                                              0,
                                              ZeroFunction<dim> (dim),
                                              boundary_values);
    MatrixTools::apply_boundary_values (boundary_values,
                                        system_matrix,
                                        locally_relevant_solution, system_rhs,
                                        false);
  }

  // Solve the assembled system with PETSc CG and block-Jacobi
  // preconditioning. The iteration cap is the number of dofs, and the target
  // residual is 1e-8 relative to ||rhs||. The constrained result is stored in
  // locally_relevant_solution. Returns the number of iterations taken.
  template <int dim>
  unsigned int ElasticProblem<dim>::solve ()
  {
    const double target_residual = 1e-8 * system_rhs.l2_norm ();
    SolverControl solver_control (locally_relevant_solution.size (),
                                  target_residual);
    PETScWrappers::SolverCG cg_solver (solver_control, mpi_communicator);

    PETScWrappers::PreconditionBlockJacobi block_jacobi (system_matrix);

    // The solver writes into a vector with the locally owned layout;
    // constrained entries start at zero.
    PETScWrappers::MPI::Vector owned_solution (locally_owned_dofs,
                                               mpi_communicator);
    constraints.set_zero (owned_solution);

    cg_solver.solve (system_matrix, owned_solution, system_rhs, block_jacobi);

    // Fill hanging-node values from the constraints, then publish the result.
    constraints.distribute (owned_solution);
    locally_relevant_solution = owned_solution;

    return solver_control.last_step ();
  }


  // Compute Kelly error indicators on this rank's cells. Then refine the 30%
  // of cells with the largest indicators and coarsen the 3% with the
  // smallest, and execute the mesh change.
  template <int dim>
  void ElasticProblem<dim>::refine_grid ()
  {
    // The Kelly estimator integrates solution jumps across faces. It
    // therefore reads dof values on neighbouring ghost cells, not only on
    // owned ones. setup_system() initializes locally_relevant_solution with
    // the locally owned index set only, so those reads would access
    // non-local entries. Estimate on a ghosted copy instead. Assigning into a
    // ghosted vector copies the owned part and then imports the ghost values.
    PETScWrappers::MPI::Vector ghosted_solution (locally_owned_dofs,
                                                 locally_relevant_dofs,
                                                 mpi_communicator);
    ghosted_solution = locally_relevant_solution;

    Vector<float> estimated_error_per_cell (triangulation.n_active_cells ());
    KellyErrorEstimator<dim>::estimate (
      dof_handler, QGauss<dim - 1> (3), typename FunctionMap<dim>::type (),
      ghosted_solution, estimated_error_per_cell, ComponentMask (),
      nullptr, 0, triangulation.locally_owned_subdomain ());

    parallel::distributed::GridRefinement::refine_and_coarsen_fixed_number (
      triangulation, estimated_error_per_cell, 0.3, 0.03);
    triangulation.execute_coarsening_and_refinement ();
  }


  // Adaptive loop. The first cycle creates [-1,1]^dim refined three times
  // globally; each later cycle refines adaptively. Every cycle then sets up,
  // assembles and solves the system.
  template <int dim>
  void ElasticProblem<dim>::run ()
  {
    const unsigned int n_cycles = 10;

    for (unsigned int cycle = 0; cycle < n_cycles; ++cycle)
      {
        if (cycle != 0)
          refine_grid ();
        else
          {
            GridGenerator::hyper_cube (triangulation, -1, 1);
            triangulation.refine_global (3);
          }

        setup_system ();
        assemble_system ();
        solve ();
      }
  }
}


// Program entry point. Initializes MPI for the duration of the try block,
// runs the 2d problem, and reports any escaping exception before returning 1.
int main (int argc, char **argv)
{
  try
    {
      using namespace dealii;
      using namespace DistributedElasticity;

      // Scope guard: MPI is initialized here and finalized when this block is
      // left. It is declared before the problem object, so it outlives it.
      // The last argument caps the number of threads per process at 1.
      Utilities::MPI::MPI_InitFinalize mpi_initialization(argc, argv, 1);

      ElasticProblem<2> elastic_problem;
      elastic_problem.run ();
    }
  catch (const std::exception &exc)
    {
      std::cerr << std::endl << std::endl
                << "----------------------------------------------------"
                << std::endl;
      std::cerr << "Exception on processing: " << std::endl
                << exc.what ()
                << std::endl << "Aborting!" << std::endl
                << "----------------------------------------------------"
                << std::endl;

      return 1;
    }
  catch (...)
    {
      std::cerr << std::endl << std::endl
                << "----------------------------------------------------"
                << std::endl;
      std::cerr << "Unknown exception!" << std::endl
                << "Aborting!" << std::endl
                << "----------------------------------------------------"
                << std::endl;
      return 1;
    }
  // NOTE(review): printed by every MPI rank.
  std::cout << "DistributedElasticity executed successfully." << std::endl;
  return 0;
}
