#include <sli/fitscc.h>
#include <sli/mdarray_statistics.h>
#include <stdlib.h>
#include "mask.h"


using namespace sli;


/*
 * Estimate the sky background of an image HDU and subtract it in place.
 *
 * The sky level is estimated by iterative one-sided sigma clipping on a
 * private copy of the pixel array. On each of a fixed number of passes,
 * pixels lying more than `clip_sigma` standard deviations ABOVE the
 * current mean are replaced by NaN; pixels below the mean are never
 * rejected. The mean of the copy after clipping is then subtracted from
 * the HDU's own float array.
 *
 * NOTE(review): this relies on md_mean()/md_stddev() skipping NaN pixels;
 * otherwise the statistics would turn into NaN after the first rejection
 * -- confirm against the SLLIB documentation.
 * NOTE(review): the clipping loop only walks axis 0 (col) and axis 1
 * (row); behavior on data with more than two axes is not covered here.
 *
 * hdu        : image HDU; its float array is modified in place.
 * sky_stddev : receives the standard deviation of the clipped pixels.
 */
static void subtract_sky(fits_image &hdu, double *sky_stddev) {
    const int n_passes = 5;         // number of clipping iterations
    const double clip_sigma = 2.0;  // rejection threshold, in units of stddev

    // Work on a copy so the NaN markers never reach the HDU itself.
    mdarray_float clipped = hdu.float_array();

    int pass = 0;
    while (pass < n_passes) {
        const double center = md_mean(clipped);
        const double spread = md_stddev(clipped);
        for (unsigned row = 0;  row < clipped.length(1);  row++) {
            for (unsigned col = 0;  col < clipped.length(0);  col++) {
                if ((clipped(col, row) - center) / spread > clip_sigma) {
                    clipped(col, row) = NAN;  // reject: too bright to be sky
                }
            }
        }
        ++pass;
    }

    const double sky_level = md_mean(clipped);
    hdu.float_array() -= sky_level;
    *sky_stddev = md_stddev(clipped);
}


/*
 * Prepare the detection mask that accompanies the primary image.
 *
 * If the container holds fewer than two HDUs, an image extension named
 * "Mask" is appended. Its type is FITS::BYTE_T, and its axis-0/axis-1
 * lengths are copied from the primary image's float array. HDU 1 is then
 * taken as the mask and accessed through uchar_array().
 *
 * NOTE(review): when the file already has two or more HDUs, HDU 1 is used
 * as the mask without checking its name, type or size -- confirm this is
 * acceptable for the expected inputs.
 *
 * fits       : FITS container; HDU 0 is accessed via float_array().
 * sky_stddev : sky noise reported by subtract_sky(); not used yet.
 *              Presumably it is meant to set the detection threshold
 *              (TODO confirm).
 */
static void mark_detected(fitscc &fits, double sky_stddev) {
    mdarray_float &pixels = fits.image(0L).float_array();

    const bool need_mask_hdu = fits.length() < 2;
    if (need_mask_hdu) {
        fits.append_image("Mask", 0, FITS::BYTE_T,
                          pixels.length(0), pixels.length(1));
    }

    mdarray_uchar &flags = fits.image(1L).uchar_array();

    // TODO: implement here -- mark every pixel of `pixels` whose value is
    // higher than some threshold as DETECTED in `flags` (DETECTED is
    // presumably defined in mask.h).
}


int main(int argc, char *argv[])
{
    if (argc != 3) {
        sli__eprintf("usage: %s INPUT OUTPUT\n", argv[0]);
        exit(1);
    }

    double sky_stddev;

    fitscc fits;
    fits.read_stream(argv[1]);
    subtract_sky(fits.image(0L), &sky_stddev);
    sli__eprintf("skylevel = 0.0 +/- %f\n", sky_stddev);
    mark_detected(fits, sky_stddev);
    fits.write_stream(argv[2]);
    return 0;
}