import ClampedErrorCorrectionFilter from "GazeFilters/ClampedErrorCorrectionFilter";
import ErrorCorrectionFilter from "GazeFilters/ErrorCorrectionFilter";
import SigmoidErrorCorrectionFilter from "GazeFilters/SigmoidErrorCorrectionFilter";
import { GazeValidatorConfig } from "Managers/ConfigManager";
import { add, dotMultiply, mean, sqrt, std } from "mathjs";
import ErrorCorrection from "Models/ErrorCorrection";
import Gaze from "Models/Gaze";
import ValidatedErrorCorrection, {
  VECPoint,
} from "Models/ValidatedErrorCorrection";

/**
 * Key identifying an error-correction bin. Coordinates are compared exactly,
 * so a sample lands in a bin only when its reference point has identical x/y.
 */
function binKey(x: number, y: number): string {
  return "(" + x + ", " + y + ")";
}

/**
 * Accuracy statistics for a set of residuals (corrected minus reference).
 *
 * @param dxs - per-sample x residuals
 * @param dys - per-sample y residuals, same length and order as `dxs`
 * @returns `accuracy`: mean Euclidean residual length;
 *   `xAccuracy` / `yAccuracy`: mean signed residual per axis.
 *   Callers must not pass empty arrays (mathjs `mean` rejects them).
 */
function residualAccuracy(dxs: number[], dys: number[]) {
  const dr2s = add(
    dotMultiply(dxs, dxs),
    dotMultiply(dys, dys)
  ) as Array<number>;
  return {
    accuracy: mean(dr2s.map((x: number) => sqrt(x))),
    xAccuracy: mean(dxs),
    yAccuracy: mean(dys),
  };
}

/**
 * Applies an error-correction filter to live gaze predictions and validates
 * that filter against collected (predicted, reference) gaze sample pairs.
 */
export default class ErrorCorrectionGazeValidator {
  config: GazeValidatorConfig;

  // Most recent gaze produced by update(); Gaze.Zero() initially and after reset().
  _gaze: Gaze;

  // Filter used both by update() and by validate().
  _errorCorrectionFilter: ErrorCorrectionFilter;

  // Clamping filter built from config.filterProperties.clamp. Currently not
  // applied: its call in update() is commented out.
  _clampedFilter: ErrorCorrectionFilter;

  // Error correction recalculated by the most recent validate();
  // ErrorCorrection.Zero() before the first validation.
  _errorCorrection: ErrorCorrection;

  // Result of the most recent validate(); ValidatedErrorCorrection.Zero() before that.
  _validatedErrorCorrection: ValidatedErrorCorrection;

  constructor(config: GazeValidatorConfig) {
    this.config = config;
    // ErrorCorrectionFilterFactory.build(config);
    // The filter type is currently fixed; see createFilter().
    this._errorCorrectionFilter = this.createFilter(ErrorCorrection.Zero());
    this._errorCorrection = ErrorCorrection.Zero();
    this._validatedErrorCorrection = ValidatedErrorCorrection.Zero();
    this._clampedFilter = new ClampedErrorCorrectionFilter(
      config.filterProperties.clamp
    );
    this._gaze = Gaze.Zero();
  }

  /**
   * @returns the error correction recalculated by the most recent validate()
   *   (not the one passed to setErrorCorrection()).
   */
  errorCorrection = (): ErrorCorrection => this._errorCorrection;

  /**
   * @returns the result of the most recent validate().
   */
  validatedErrorCorrection = (): ValidatedErrorCorrection =>
    this._validatedErrorCorrection;

  /**
   * @returns the filter currently applied by update() and used by validate().
   */
  errorCorrectionFilter = (): ErrorCorrectionFilter =>
    this._errorCorrectionFilter;

  /**
   * Replaces the active filter with a new one seeded with `errorCorrection`.
   *
   * The new filter is constructed without the previous filter's samples, and
   * the value returned by errorCorrection() is left unchanged.
   *
   * @param errorCorrection - correction the new filter starts from
   */
  setErrorCorrection = (errorCorrection: ErrorCorrection) => {
    this._errorCorrectionFilter = this.createFilter(errorCorrection);
  };

  /**
   * Clears the most recent gaze and calls resetSamples() on the active filter.
   * errorCorrection() and validatedErrorCorrection() are not reset.
   * Declared async so callers may await it; it performs no asynchronous work.
   */
  reset = async () => {
    this._gaze = Gaze.Zero();

    this._errorCorrectionFilter.resetSamples();
  };

  /**
   * Records one validation sample on the active filter.
   *
   * @param predictedGaze - uncorrected gaze predicted for this sample
   * @param measuredGaze - reference gaze for the same sample (presumably the
   *   on-screen target; it is what validate() reads back via sampledGazes())
   */
  addValidationData = (predictedGaze: Gaze, measuredGaze: Gaze) => {
    this._errorCorrectionFilter.addSample(predictedGaze, measuredGaze);
  };

  /**
   * Measures how well the active filter corrects the collected samples.
   *
   * Each predicted gaze is passed through the filter and compared with its
   * reference gaze. Residuals are grouped into one bin per error-correction
   * point, matched on the exact x/y of the reference. Each bin yields a
   * VECPoint with its own accuracy and precision. Bulk accuracy uses every
   * residual. Error-correction points that received no samples are skipped
   * instead of aborting the validation.
   *
   * Side effects:
   * - stores the result (see validatedErrorCorrection());
   * - replaces errorCorrection() with one recalculated from the filter's samples;
   * - clears those samples when config.resetCorrections is set.
   *
   * @returns the new ValidatedErrorCorrection
   * @throws Error if no samples were collected, if the predicted and reference
   *   lists differ in length, or if no sample matched any error-correction point
   */
  validate = () => {
    const errorCorrectionFilter = this._errorCorrectionFilter;
    let errorCorrection = this.errorCorrection();

    // With no stored correction there are no bins, so derive one from the samples.
    if (!errorCorrection || errorCorrection.points().length === 0) {
      errorCorrection = errorCorrectionFilter.calculateErrorCorrection();
    }

    const predicted = errorCorrectionFilter.predictedGazes(); // the time-averaged raw gaze
    const modified = predicted.map((p) => errorCorrectionFilter.apply(p));
    const sampled = errorCorrectionFilter.sampledGazes(); // the reference points

    if (predicted.length === 0) {
      throw new Error(
        "ErrorCorrectionGazeValidator.validate: no validation samples collected"
      );
    }
    if (sampled.length !== predicted.length) {
      throw new Error(
        `ErrorCorrectionGazeValidator.validate: ${predicted.length} predicted gazes but ${sampled.length} reference gazes`
      );
    }

    const timestamp = predicted[0].timestamp();

    // One residual bin per error-correction point.
    const bins = new Map<string, { dxs: number[]; dys: number[] }>();
    for (const ecPoint of errorCorrection.points()) {
      bins.set(binKey(ecPoint.x(), ecPoint.y()), { dxs: [], dys: [] });
    }

    // Residual (corrected - reference) for every sample. All of them feed the
    // bulk statistics; those whose reference matches a point also feed its bin.
    const dxs: number[] = [];
    const dys: number[] = [];
    let duration = 0.0;
    for (let i = 0; i < predicted.length; i++) {
      const dx = modified[i].x() - sampled[i].x();
      const dy = modified[i].y() - sampled[i].y();
      dxs.push(dx);
      dys.push(dy);

      const bin = bins.get(binKey(sampled[i].x(), sampled[i].y()));
      if (bin) {
        bin.dxs.push(dx);
        bin.dys.push(dy);
      }

      duration += predicted[i].duration();
    }

    // Per-point statistics.
    const vecPoints: VECPoint[] = [];
    for (const ecPoint of errorCorrection.points()) {
      const bin = bins.get(binKey(ecPoint.x(), ecPoint.y()));
      // mathjs mean()/std() reject empty arrays; a point without samples has no stats.
      if (!bin || bin.dxs.length === 0) {
        continue;
      }

      const pointStats = residualAccuracy(bin.dxs, bin.dys);
      // Number() pins the static type: mathjs's declared return type for
      // std(array) is not always `number`, but the runtime value is a scalar.
      const pointXPrecision = Number(std(bin.dxs));
      const pointYPrecision = Number(std(bin.dys));
      const pointPrecision = sqrt(
        pointXPrecision * pointXPrecision + pointYPrecision * pointYPrecision
      ) as number;
      vecPoints.push(
        new VECPoint(
          ecPoint.x(),
          ecPoint.y(),
          ecPoint.xCorrection(),
          ecPoint.yCorrection(),
          pointPrecision,
          pointXPrecision,
          pointYPrecision,
          pointStats.accuracy,
          pointStats.xAccuracy,
          pointStats.yAccuracy
        )
      );
    }

    if (vecPoints.length === 0) {
      throw new Error(
        "ErrorCorrectionGazeValidator.validate: no validation sample matched an error-correction point"
      );
    }

    // Bulk statistics.
    const bulk = residualAccuracy(dxs, dys);

    // Bulk precision keeps the "old" definition: the mean of the per-point
    // precisions, not std() pooled over all residuals. The original author
    // doubted this definition; changing it would change the reported numbers.
    const xPrecision = mean(vecPoints.map((vecPoint) => vecPoint.xPrecision()));
    const yPrecision = mean(vecPoints.map((vecPoint) => vecPoint.yPrecision()));
    const precision = mean(vecPoints.map((vecPoint) => vecPoint.precision()));

    // Framerates scale by 1000.0 (presumably durations are in milliseconds).
    // A zero duration yields Infinity/NaN here.
    const misc = {
      errorCorrectionTimestamp: errorCorrection.timestamp(),
      errorCorrectionDuration: errorCorrection.duration(),
      errorCorrectionFramerate:
        (1000.0 * errorCorrection.sampleCount()) / errorCorrection.duration(),
      validatedErrorCorrectionFramerate: (1000.0 * predicted.length) / duration,
    };
    const validatedErrorCorrection = new ValidatedErrorCorrection(
      timestamp,
      duration,
      vecPoints,
      predicted.length,
      precision,
      xPrecision,
      yPrecision,
      bulk.accuracy,
      bulk.xAccuracy,
      bulk.yAccuracy,
      misc
    );
    this._errorCorrection = errorCorrectionFilter.calculateErrorCorrection();
    this._validatedErrorCorrection = validatedErrorCorrection;

    if (this.config.resetCorrections) {
      this._errorCorrectionFilter.resetSamples();
    }

    return validatedErrorCorrection;
  };

  /**
   * Applies the active filter to a predicted gaze and remembers the result.
   *
   * @param predictedGaze - uncorrected gaze prediction
   * @returns the corrected gaze (also available via nextGaze())
   */
  update = (predictedGaze: Gaze): Gaze => {
    const gaze = this._errorCorrectionFilter.apply(predictedGaze);

    // this._gaze = this._clampedFilter.apply(gaze);
    this._gaze = gaze;
    return this._gaze;
  };

  /**
   * @returns whether a gaze is stored (true unless the filter produced null/undefined).
   */
  hasNextGaze = (): boolean => this._gaze != null;

  /**
   * @returns the most recent gaze from update(), or Gaze.Zero() if none yet.
   */
  nextGaze = (): Gaze => this._gaze;

  /**
   * Builds the filter used by this validator, seeded with `errorCorrection`.
   * Shared by the constructor and setErrorCorrection() so both stay in sync.
   */
  private createFilter(errorCorrection: ErrorCorrection): ErrorCorrectionFilter {
    return new SigmoidErrorCorrectionFilter(
      this.config.filterProperties.nNearest,
      this.config.filterProperties.beta,
      errorCorrection
    );
  }
}
