#ifndef ADAPTIVESAMPLING_HPP
#define	ADAPTIVESAMPLING_HPP

#include <algorithm>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <random>
#include <utility>
#include <vector>

#include "Metric.hpp"
#include "WeightedPoint.hpp"
#include "Randomness.hpp"

/**
 * @brief Weighted sampling for k-means++ and similar algorithms
 */
class AdaptiveSampling
{
public:
    /**
     * @brief Constructs the sampler
     * @param createMetric Factory returning the Metric<Point> used for distance
     *        computations (the constructor body is defined elsewhere; presumably
     *        the result is stored in the private metric pointer — confirm)
     */
    AdaptiveSampling(std::function<Metric<Point>*()> createMetric);

    // NOTE(review): this class holds a raw Metric pointer and declares a
    // destructor but no copy operations. If the destructor deletes the metric
    // (not visible here), copying an instance would double-delete — confirm
    // and consider deleting the copy constructor / copy assignment.
    virtual ~AdaptiveSampling();

    /**
     * @brief Selects a set of centers from a point range by adaptive (weighted) sampling
     * @param begin Start of the input point range
     * @param end End of the input point range
     * @param k How many centers to select
     * @param n Number of points in [begin, end); the default 0 means the range
     *        is traversed once to count them
     * @return Owning pointer to a vector holding the selected centers
     */
    template<typename ForwardIterator>
    std::unique_ptr<std::vector<Point>> computeCenterSet(ForwardIterator begin, ForwardIterator end, size_t k, size_t n = 0);

private:
    // Metric used to weight points by their distance to already chosen centers.
    Metric<Point>* metric;
};

// Adaptive (distance-weighted) seeding, as used by k-means++:
//  1. the first center is drawn uniformly from the n input points;
//  2. every further center is drawn with probability proportional to the
//     point's weight, i.e. its smallest metric->distance() to any center
//     chosen so far.
// NOTE(review): classic k-means++ weights by *squared* distance; this relies on
// the configured Metric providing the intended weighting — confirm.
//
// Random numbers come from the shared generator Randomness::getMT19937().
// n must not exceed the length of [begin, end): both passes below advance the
// iterator n times.
// If k == 0 or there are no points, an empty center set is returned instead of
// indexing out of bounds.
template<typename ForwardIterator>
std::unique_ptr<std::vector<Point >> AdaptiveSampling::computeCenterSet(ForwardIterator begin, ForwardIterator end, size_t k, size_t n)
{
    typedef typename std::iterator_traits<ForwardIterator>::difference_type Diff;

    // Count the points if the caller did not supply n.
    if (n == 0)
        n = static_cast<size_t>(std::distance(begin, end));

    // Nothing to draw from (or nothing requested): avoids writing (*centers)[0]
    // out of bounds and the n - 1 underflow in the uniform distribution below.
    if (k == 0 || n == 0)
        return std::unique_ptr<std::vector<Point>>(new std::vector<Point>());

    std::mt19937* rng = Randomness::getMT19937();
    std::unique_ptr<std::vector<Point>> centers(new std::vector<Point>(k));

    // Draw first center uniformly; size_t (not int) so a large n cannot overflow.
    std::uniform_int_distribution<size_t> uniformFirst(0, n - 1);
    ForwardIterator firstCenter = begin;
    std::advance(firstCenter, static_cast<Diff>(uniformFirst(*rng)));
    (*centers)[0] = *firstCenter;

    // Draw remaining centers
    std::uniform_real_distribution<> uniformReal(0, 1);
    // weights[p] = distance from point p to its nearest center chosen so far.
    std::vector<double> weights(n, std::numeric_limits<double>::infinity());
    for (size_t c = 1; c < k; ++c)
    {
        // Update the nearest-center weight of every point. Only the center added
        // last can have lowered it.
        double sumWeights = 0;
        {
            ForwardIterator it = begin;
            for (size_t p = 0; p < n; ++p, ++it)
            {
                double dist = metric->distance((*centers)[c - 1], *it);
                if (dist < weights[p])
                    weights[p] = dist;
                sumWeights += weights[p];
            }
        }

        // Pick the point whose cumulative-weight interval contains a uniform
        // draw from [0, sumWeights). The draw must be scaled by sumWeights:
        // comparing a raw [0, 1) draw against unnormalized weights heavily
        // biases the choice toward the first points of the range.
        {
            double threshold = uniformReal(*rng) * sumWeights;
            double cumulative = 0;
            ForwardIterator chosen = begin;
            ForwardIterator it = begin;
            for (size_t p = 0; p < n; ++p, ++it)
            {
                // Zero-weight (or NaN) points coincide with an existing center and
                // are never chosen. This also makes the rounding fallback (loop
                // ends without a break) land on the last selectable point.
                if (!(weights[p] > 0))
                    continue;
                chosen = it;
                cumulative += weights[p];
                if (cumulative > threshold)
                    break;
            }
            // If every weight is zero, chosen stays at begin (a duplicate center).
            (*centers)[c] = *chosen;
        }
    }

    return centers;
}

#endif	/* ADAPTIVESAMPLING_HPP */

