import ErrorCorrectionFilter from "GazeFilters/ErrorCorrectionFilter";
import { mean, sum } from "mathjs";
import ErrorCorrection, { ECPoint } from "Models/ErrorCorrection";
import Gaze from "Models/Gaze";

export default class SigmoidErrorCorrectionFilter extends ErrorCorrectionFilter {
  // current correction used by apply(); replaced by calculateErrorCorrection()
  _errorCorrection: ErrorCorrection;
  // number of nearest validated points blended into each correction
  _nNearest: number;
  // sigmoid steepness parameter (the exponent uses k = 2 * beta)
  _beta: number;

  /**
   * @param {number} nNearest number of nearest validated points to blend per gaze
   * @param {number} beta steepness of the distance-normalized sigmoid; the exponent
   *   uses k = 2 * beta, i.e. beta is twice the magnitude of the sigmoid's slope
   *   at the mean distance
   * @param {ErrorCorrection} errorCorrection initial correction
   *   (defaults to ErrorCorrection.Zero())
   */
  constructor(
    nNearest: number = 2,
    beta: number = 7.0,
    errorCorrection: ErrorCorrection = ErrorCorrection.Zero()
  ) {
    super();
    this._errorCorrection = errorCorrection;
    this._nNearest = nNearest;
    this._beta = beta;
  }

  /**
   * @return {ErrorCorrection} the previously calculated (or initially supplied)
   *   error correction
   */
  errorCorrection = (): ErrorCorrection => {
    return this._errorCorrection;
  };

  /**
   * Apply the filter: shift the gaze by a weighted blend of the corrections of the
   * `nNearest` validated points closest to it.
   *
   *   x = x_0 + sum_i w_i * c_{x,i},   y = y_0 + sum_i w_i * c_{y,i}
   *
   * where c_{x,i}, c_{y,i} are the validated point corrections and w_i are the
   * normalized values of the distance-normalized sigmoid
   *
   *   w(s) = 1 / (1 + exp[k * (s - <s>) / <s>]),   k = 2 * beta
   *
   * with <s> the mean distance of the selected points. If every selected point is
   * at distance 0 (<s> = 0) the selected points are weighted equally.
   *
   * @param {Gaze} gaze uncorrected gaze
   * @return {Gaze} a new Gaze; coordinates are unchanged when there is no
   *   correction or it has no validated points
   * @throws {string} when fewer than `nNearest` validated points are available
   */
  apply = (gaze: Gaze): Gaze => {
    const errorCorrection = this._errorCorrection;

    if (!errorCorrection || errorCorrection.points().length === 0) {
      return new Gaze(
        gaze.index(),
        gaze.timestamp(),
        gaze.duration(),
        gaze.x(),
        gaze.y()
      );
    }

    if (this._nNearest > errorCorrection.points().length) {
      // NOTE: a plain string (not an Error) is thrown to stay compatible with
      // existing callers that may inspect the thrown value.
      throw "SigmoidErrorCorrectionFilter: not enough validated points!";
    }

    // Rank points by distance, computing each distance only once.
    // A question was raised about whether this should be the n nearest points
    // after applying the modification; the old tracker finds the distance
    // before applying it, so the uncorrected gaze is used here.
    const nearest = Array.from(errorCorrection.points(), (point) => ({
      point,
      distance: point.distance(gaze),
    }))
      .sort((a, b) => a.distance - b.distance)
      .slice(0, this._nNearest);

    // mean distance of the selected points
    const s0 = sum(nearest.map((entry) => entry.distance)) / nearest.length;
    // extra factor of 2 because beta = 0.5 * k
    const k = 2 * this._beta;

    const weighted = nearest.map(({ point, distance }) => {
      // Sign chosen so that closer-than-average points get weight > 0.5.
      // When s0 is 0 every selected distance is 0 and (s - s0) / s0 would be
      // 0 / 0 = NaN; use the limiting exponent 0 (equal weights) instead.
      const exponent = s0 > 0 ? (k * (distance - s0)) / s0 : 0;
      return { point, weight: 1.0 / (1.0 + Math.exp(exponent)) };
    });

    // normalize weights and accumulate the blended correction
    const totalWeight = sum(weighted.map((entry) => entry.weight));
    let dx = 0;
    let dy = 0;

    for (const { point, weight } of weighted) {
      const normalized = weight / totalWeight;
      dx += normalized * point.xCorrection();
      dy += normalized * point.yCorrection();
    }

    return new Gaze(
      gaze.index(),
      gaze.timestamp(),
      gaze.duration(),
      gaze.x() + dx,
      gaze.y() + dy
    );
  };

  /**
   * Calculate the error correction from the filter's collected samples and store
   * it as this filter's current correction.
   *
   * Predicted gazes are paired with sampled gazes by index and binned by sampled
   * position. Each bin yields one ECPoint whose correction is the sampled position
   * minus the mean predicted position of that bin.
   *
   * @return {ErrorCorrection} the newly calculated error correction
   * @throws {Error} when there are no predicted gazes, or fewer predicted gazes
   *   than sampled gazes
   */
  calculateErrorCorrection = (): ErrorCorrection => {
    const predicted = this.predictedGazes();
    const sampled = this.sampledGazes();

    // Both situations previously crashed with a TypeError on an undefined element.
    if (predicted.length === 0) {
      throw new Error(
        "SigmoidErrorCorrectionFilter: no predicted gazes to calculate an error correction from"
      );
    }
    if (predicted.length < sampled.length) {
      throw new Error(
        "SigmoidErrorCorrectionFilter: fewer predicted gazes than sampled gazes"
      );
    }

    type SampledGaze = (typeof sampled)[number];
    type PredictedGaze = (typeof predicted)[number];

    // Bin predictions by sampled position. Map preserves insertion order, so
    // ECPoints are emitted in the order each position is first seen.
    const bins = new Map<
      string,
      { sample: SampledGaze; predicted: PredictedGaze[] }
    >();
    const timestamp = predicted[0].timestamp();
    let duration = 0.0;

    for (let i = 0; i < sampled.length; i++) {
      const sample = sampled[i];
      const key = "(" + sample.x() + ", " + sample.y() + ")";
      let bin = bins.get(key);

      if (!bin) {
        bin = { sample, predicted: [] };
        bins.set(key, bin);
      }

      bin.predicted.push(predicted[i]);
      duration += predicted[i].duration();
    }

    const ecPoints: ECPoint[] = [];

    bins.forEach(({ sample, predicted: binPredicted }) => {
      const xMean = mean(binPredicted.map((p) => p.x()));
      const yMean = mean(binPredicted.map((p) => p.y()));
      const dx = sample.x() - xMean;
      const dy = sample.y() - yMean;
      ecPoints.push(new ECPoint(sample.x(), sample.y(), dx, dy));
    });

    const errorCorrection = new ErrorCorrection(
      timestamp,
      duration,
      ecPoints,
      predicted.length
    );
    this._errorCorrection = errorCorrection;
    return errorCorrection;
  };
}
