#include <sli/fitscc.h>
#include <sli/mdarray_statistics.h>
#include <stdlib.h>
#include <vector> // 可変長配列
#include <stdio.h>
#include <assert.h>
#include <math.h>
#include "mask.h"


using namespace sli;


/*
 * Estimate the sky background of `hdu` by iterative upper sigma clipping,
 * subtract it from the image in place, and report the remaining scatter.
 *
 * hdu        : image HDU; the estimated sky level is subtracted from every
 *              pixel of its float array.
 * sky_stddev : receives md_stddev() of the clipped copy.
 *
 * Clipping happens on a private copy only.  On each pass, every pixel lying
 * more than `clip_sigma` standard deviations ABOVE the current mean is
 * replaced by NAN in that copy; pixels below the mean are never clipped.
 * NOTE(review): this relies on md_mean()/md_stddev() ignoring NAN values —
 * confirm against the SLI mdarray_statistics documentation.
 */
static void subtract_sky(fits_image &hdu, double *sky_stddev) {
    const int n_passes = 5;
    const double clip_sigma = 2.0;

    mdarray_float work = hdu.float_array();

    for (int pass = n_passes;  pass > 0;  pass--) {
        const double mu = md_mean(work);
        const double sd = md_stddev(work);
        for (unsigned row = 0;  row < work.length(1);  row++) {
            for (unsigned col = 0;  col < work.length(0);  col++) {
                const double z = (work(col, row) - mu) / sd;
                if (z > clip_sigma)
                    work(col, row) = NAN;
            }
        }
    }

    // The mean of the clipped copy is taken as the sky level.
    const double sky_level = md_mean(work);
    hdu.float_array() -= sky_level;
    *sky_stddev = md_stddev(work);
}


/*
 * Set the DETECTED bit in the mask (HDU 1) for every pixel of the image
 * (HDU 0) whose value exceeds `threshold` times sky_stddev.
 *
 * If the file does not have a second HDU yet, a BYTE image named "Mask"
 * with the same axis-0/axis-1 lengths as the image is appended first.
 */
static void mark_detected(fitscc &fits, double sky_stddev) {
    const double threshold = 3.0;

    if (fits.length() < 2) {
        mdarray_float &image = fits.image(0L).float_array();
        const size_t width  = image.length(0);
        const size_t height = image.length(1);
        fits.append_image("Mask", 0, FITS::BYTE_T, width, height);
    }

    // Take the references only after any append has happened.
    mdarray_float &data = fits.image(0L).float_array();
    mdarray_uchar &mask = fits.image(1L).uchar_array();
    const double cut = sky_stddev * threshold;

    for (unsigned row = 0;  row < data.length(1);  row++) {
        for (unsigned col = 0;  col < data.length(0);  col++) {
            if (data(col, row) > cut)
                mask(col, row) |= DETECTED;
        }
    }
}


/* Integer pixel coordinate: x indexes axis 0, y indexes axis 1 of the image arrays. */
struct point_t {
    int x;
    int y;
};


static void containing_box(const std::vector<point_t> &pixels, int *min_x, int *max_x, int *min_y, int *max_y) {
    assert(pixels.size() > 0); // assertの括弧内の条件が満たされていなければその場で終了する
    *min_x = *max_x = pixels[0].x;
    *min_y = *max_y = pixels[0].y;
    for (unsigned i = 0;  i < pixels.size();  i++) {
        if (*min_x > pixels[i].x) *min_x = pixels[i].x;
        if (*max_x < pixels[i].x) *max_x = pixels[i].x;
        if (*min_y > pixels[i].y) *min_y = pixels[i].y;
        if (*max_y < pixels[i].y) *max_y = pixels[i].y;
    }
}


/*
 * Measure one detected source and print its shape parameters to stdout.
 *
 * pixels : coordinates of the connected pixels that make up the source
 *          (must be non-empty; containing_box() asserts this).
 * data   : image the source was detected in.
 *
 * All sums run over EVERY pixel of the rectangle that bounds `pixels`,
 * not only over the pixels listed in `pixels`.
 *
 * Printed line (printf "% e" format): "cx cy flux a b theta", where
 *   cx, cy : flux-weighted centroid,
 *   flux   : sum of the pixel values in the bounding rectangle,
 *   a, b   : square roots of the larger / smaller eigenvalue of the
 *            flux-weighted second-moment matrix [[xx, xy], [xy, yy]],
 *   theta  : 0.5 * atan2(2 xy, xx - yy), in radians.
 * Nothing is printed when the total flux is not positive, or when the
 * smaller eigenvalue is negative (possible because pixel values can be < 0).
 */
static void measure(const std::vector<point_t> &pixels, const mdarray_float &data) {
    // Rectangle enclosing the source pixels.
    int min_x, max_x, min_y, max_y;
    containing_box(pixels, &min_x, &max_x, &min_y, &max_y);

    // Zeroth and first moments: total flux and flux-weighted centroid.
    double cx = 0., cy = 0., flux = 0.;
    for (int x = min_x;  x <= max_x;  x++) {
        for (int y = min_y;  y <= max_y;  y++) {
            const float v = data(x, y);
            cx += x * v;
            cy += y * v;
            flux += v;
        }
    }
    // The centroid is a flux-weighted mean.  A zero total flux would divide
    // by zero (inf/NaN), and a negative one gives weights of the wrong sign,
    // so the moments below would be meaningless: skip such sources.
    if (!(flux > 0.))
        return;
    cx /= flux;
    cy /= flux;

    // Flux-weighted second central moments.
    double xx = 0., yy = 0., xy = 0.;
    for (int x = min_x;  x <= max_x;  x++) {
        for (int y = min_y;  y <= max_y;  y++) {
            const float v = data(x, y);
            xx += (cx - x)*(cx - x) * v;
            yy += (cy - y)*(cy - y) * v;
            xy += (cx - x)*(cy - y) * v;
        }
    }
    xx /= flux;
    yy /= flux;
    xy /= flux;

    // Orientation and eigenvalues (a2 >= b2) of the second-moment matrix.
    double theta = atan2(2.0 * xy, xx - yy) / 2.0,
           t = sqrt((xx - yy)*(xx - yy) + 4 * xy*xy),
           a2 = (xx + yy + t) / 2.,
           b2 = (xx + yy - t) / 2.;
    if (b2 >= 0.) {
        // Pixel values can be negative, so the moments (and b2) can be too;
        // only sources with a real minor axis are reported.
        double a = sqrt(a2),
               b = sqrt(b2);
        printf("% e % e % e % e % e % e\n", cx, cy, flux, a, b, theta);
    }
}


/*
 * Group the DETECTED pixels of the mask (HDU 1) into 8-connected regions.
 * Every region of at least `min_area` pixels gets the SOURCE bit in the
 * mask HDU and is passed, together with the image (HDU 0), to measure().
 *
 * DETECTED bits are cleared only in a private copy of the mask, which is
 * used to remember visited pixels; the mask HDU itself only gains SOURCE
 * bits here.
 */
static void pickup_connecting_pixels(fitscc &fits) {
    const unsigned min_area = 20;
    mdarray_float &data = fits.image(0L).float_array();
    mdarray_uchar &mask_ref = fits.image(1L).uchar_array();
    mdarray_uchar mask = mask_ref;
    const int width  = static_cast<int>(mask.length(0));
    const int height = static_cast<int>(mask.length(1));

    for (unsigned y = 0;  y < mask.length(1);  y++) {
        for (unsigned x = 0;  x < mask.length(0);  x++) {
            if (!(mask(x, y) & DETECTED))
                continue;

            // Breadth-first flood fill: `pixels` doubles as the work queue;
            // entries before index `done` have had their neighbours examined.
            std::vector<point_t> pixels;
            point_t seed;
            seed.x = x;
            seed.y = y;
            pixels.push_back(seed);
            mask(x, y) &= ~DETECTED;   // clear the bit so each pixel is picked up only once
            for (unsigned done = 0;  done < pixels.size();  done++) {
                // Scan the 3x3 neighbourhood (the pixel itself included).
                for (int dx = -1;  dx <= 1;  dx++) {
                    for (int dy = -1;  dy <= 1;  dy++) {
                        const int nx = pixels[done].x + dx,
                                  ny = pixels[done].y + dy;
                        // Neighbours outside the image do not exist.  Without
                        // this check, a source touching the edge indexes the
                        // mask out of range.
                        if (nx < 0 || nx >= width || ny < 0 || ny >= height)
                            continue;
                        if (mask(nx, ny) & DETECTED) {
                            point_t p;
                            p.x = nx;
                            p.y = ny;
                            pixels.push_back(p);
                            mask(nx, ny) &= ~DETECTED;
                        }
                    }
                }
            }

            if (pixels.size() >= min_area) {
                for (unsigned i = 0;  i < pixels.size();  i++)
                    mask_ref(pixels[i].x, pixels[i].y) |= SOURCE;
                measure(pixels, data);
            }
        }
    }
}


// OPTIMIZE: use less memory.  Every filter below materialises a stack of
// (2*s+1)^2 full-size copies of the image before reducing it.

/*
 * Build a 3-D stack with one layer per offset (dx, dy), -s <= dx, dy <= s.
 * Layer z holds `data` pasted at offset (dx, dy), so reducing the stack
 * along z pixel by pixel yields a (2*s+1) x (2*s+1) spatial filter.
 * NOTE(review): the parts of a layer not covered by the pasted copy keep
 * the mdarray_float constructor's initial value (presumably 0).  Those
 * values take part in the reduction near the image edges — confirm.
 */
static mdarray_float shifted_stack(const mdarray_float &data, int s) {
    mdarray_float stack(false, data.length(0), data.length(1), (2*s + 1) * (2*s + 1));
    int z = 0;
    for (int x = -s;  x <= s;  x++) {
        for (int y = -s;  y <= s;  y++) {
            stack.paste(data, x, y, z);
            z++;
        }
    }
    return stack;
}


/* Median over the (2*s+1) x (2*s+1) neighbourhood of every pixel. */
static mdarray_float median_spatial_filter(const mdarray_float &data, int s) {
    mdarray_float stack = shifted_stack(data, s);
    stack = md_median_small_z(stack);
    return stack;
}


/* Maximum over the (2*s+1) x (2*s+1) neighbourhood of every pixel. */
static mdarray_float max_spatial_filter(const mdarray_float &data, int s) {
    mdarray_float stack = shifted_stack(data, s);
    stack = md_max_small_z(stack);
    return stack;
}


/* Minimum over the (2*s+1) x (2*s+1) neighbourhood of every pixel. */
static mdarray_float min_spatial_filter(const mdarray_float &data, int s) {
    mdarray_float stack = shifted_stack(data, s);
    stack = md_min_small_z(stack);
    return stack;
}


/*
 * OR the COSMICRAY bit into the mask for pixels that look like cosmic-ray hits.
 *
 * image_hdu  : image HDU; its float array is only read here.
 * mask_hdu   : mask HDU; accessed as a uchar array.
 * sky_stddev : sky noise estimate (as produced by subtract_sky()).
 *
 * A pixel is flagged when BOTH of these hold:
 *   - the 3x3 maximum of a high-pass residual exceeds the residual's mean
 *     by more than `threshold` residual standard deviations.  The residual
 *     is the image minus its 7x7 median, applied three times.
 *   - the 3x3 minimum of the image at that pixel is below sky_stddev.
 * NOTE(review): the second condition presumably keeps pixels inside bright
 * extended objects from being flagged — confirm intent.
 */
static void flag_cosmicray(fits_image &image_hdu, fits_image &mask_hdu, double sky_stddev) {
    const double threshold = 0.5;
    mdarray_float &data = image_hdu.float_array();
    mdarray_uchar &mask = mask_hdu.uchar_array();

    mdarray_float local_min;
    local_min = min_spatial_filter(data, 1);

    mdarray_float residual;
    residual = data;
    for (int pass = 3;  pass > 0;  pass--) {
        const mdarray_float smooth = median_spatial_filter(residual, 3);
        residual -= smooth;
    }
    residual = max_spatial_filter(residual, 1);

    const double resid_stddev = md_stddev(residual);
    const double resid_mean   = md_mean(residual);
    const double cut = threshold * resid_stddev;

    for (unsigned row = 0;  row < data.length(1);  row++) {
        for (unsigned col = 0;  col < data.length(0);  col++) {
            const bool sharp_peak = (residual(col, row) - resid_mean) > cut;
            if (sharp_peak && local_min(col, row) < sky_stddev)
                mask(col, row) |= COSMICRAY;
        }
    }
}


/*
 * Replace every pixel flagged COSMICRAY in the mask (HDU 1) by NAN in the
 * image (HDU 0).  Then fill the NAN pixels, pass after pass, with the
 * 5x5 median image (median_spatial_filter(data, 2)) until none remain.
 *
 * Each pass takes the median image of the data as it was at the start of
 * that pass.  If a pass leaves every NAN pixel NAN, the next pass would
 * see identical data and also fix nothing.  Once the NAN count stops
 * decreasing, the loop stops with a warning instead of spinning forever.
 * NOTE(review): whether a pixel can be filled depends on how
 * md_median_small_z treats NAN inputs — confirm against SLI docs.
 */
static void repair_cosmicray(fitscc &fits) {
    mdarray_float &data = fits.image(0L).float_array();
    mdarray_uchar &mask = fits.image(1L).uchar_array();

    for (unsigned y = 0;  y < mask.length(1);  y++) {
        for (unsigned x = 0;  x < mask.length(0);  x++) {
            if (mask(x, y) & COSMICRAY)
                data(x, y) = NAN;
        }
    }

    int prev_count = -1;   // NAN count seen by the previous pass; -1 = no previous pass
    for (;;) {
        int nan_count = 0;
        mdarray_float median;
        median = median_spatial_filter(data, 2);
        for (unsigned y = 0;  y < data.length(1);  y++) {
            for (unsigned x = 0;  x < data.length(0);  x++) {
                if (isnan(data(x, y))) {
                    nan_count++;
                    data(x, y) = median(x, y);
                }
            }
        }
        sli__eprintf("repairing: %d NAN pixels\n", nan_count);
        if (nan_count == 0)
            break;
        // Same count as last pass => the last pass filled nothing, and the
        // data (hence this pass) is unchanged: further passes cannot help.
        if (prev_count >= 0 && nan_count >= prev_count) {
            sli__eprintf("repairing: no progress, giving up with %d NAN pixels left\n", nan_count);
            break;
        }
        prev_count = nan_count;
    }
}


int main(int argc, char *argv[])
{
    if (argc != 3) {
        sli__eprintf("usage: %s INPUT OUTPUT\n", argv[0]);
        exit(1);
    }

    double sky_stddev = 0.0; // 警告抑止のため0.0を代入

    fitscc fits;
    fits.read_stream(argv[1]);
    subtract_sky(fits.image(0L), &sky_stddev);
    sli__eprintf("skylevel = 0.0 +/- %f\n", sky_stddev);
    flag_cosmicray(fits.image(0L), fits.image(1L), sky_stddev);
    repair_cosmicray(fits);
    mark_detected(fits, sky_stddev);
    pickup_connecting_pixels(fits);
    fits.write_stream(argv[2]);
    return 0;
}