import astropy as ap
import numpy as np
from astropy.io import fits
import os
import math as m
import time as t
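# Convention: right ascension (RA) runs along the image y-axis and declination (Dec) along the x-axis.
# Coordinate pairs are NOT ordered consistently between functions (some use (x, y), others (y, x) or (dec, ra)),
# so check each function's docstring before passing pairs around.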


def pixels_to_sky(x_coord, y_coord, head, convert=360000):
    """
    Purpose: Convert a pixel location in a .fits image to sky coordinates (declination, right ascension)
    :param x_coord: The x-coordinate in pixels to be converted; the x-axis carries declination
    :param y_coord: The y-coordinate in pixels to be converted; the y-axis carries right ascension
    :param head: The FITS header of the image. It must define 'CRPIX1' and 'CRPIX2' (the reference pixel along the
    y- and x-axis, in pixels) and 'CRVAL1' and 'CRVAL2' (the RA and Dec of that reference pixel, in degrees)
    :param convert: Multiplier from degrees to the unit of one pixel; the default 360000 gives 0.01 arcseconds
    :return: (dec, ra) of the given point as floats, in the unit selected by `convert`
    """
    # one pixel step equals one `convert`-unit step on the sky, so the mapping is a plain offset from the reference
    ref_row, ref_col = head['CRPIX1'], head['CRPIX2']
    ra = (y_coord - ref_row) + head['CRVAL1'] * convert
    dec = (x_coord - ref_col) + head['CRVAL2'] * convert
    return dec, ra


def sky_to_pixels(dec_coord, ra_coord, head, convert=360000):
    """
    Purpose: Convert right ascension and declination to the nearest pixel location in a .fits image
    :param dec_coord: The declination to be converted, in the unit selected by `convert` (0.01 arcseconds by default)
    :param ra_coord: The right ascension to be converted, in the unit selected by `convert` (0.01 arcseconds by
    default)
    :param head: The FITS header of the image. It must define 'CRPIX1' and 'CRPIX2' (the reference pixel along the
    y- and x-axis, in pixels) and 'CRVAL1' and 'CRVAL2' (the RA and Dec of that reference pixel, in degrees)
    :param convert: Defines a unit conversion away from degrees to the unit of each pixel in the .fits image
    :return: The x- and y-coordinate (ints) of the pixel nearest to the given point
    """
    y_height = head['CRPIX1']
    x_width = head['CRPIX2']
    ra_ref = head['CRVAL1'] * convert
    dec_ref = head['CRVAL2'] * convert
    # Round to the nearest pixel. int() truncates, so floating-point noise such as 599.9999999 would land one pixel
    # low, and offsets below pixel 0 would be truncated towards zero instead of down.
    x_coord = int(round(dec_coord - dec_ref + x_width))
    y_coord = int(round(ra_coord - ra_ref + y_height))
    return x_coord, y_coord


def write_point(x_coord, y_coord, data, mag=10):
    """
    Purpose: Write a point onto a .fits file image array, in place. The image must already be opened.
    :param x_coord: x-coordinate (column index) in pixels of the point to be written
    :param y_coord: y-coordinate (row index) in pixels of the point to be written
    :param data: The data of the .fits image stored as a 2-D array, indexed [y, x]
    :param mag: Magnitude-like value of the point as a float. Default is 10. The pixel is set to 2.512 ** mag, so a
    larger value gives a brighter pixel (the reverse of the astronomical magnitude convention)
    :return: (none)
    :raises IndexError: if the point lies outside the image. Negative coordinates are rejected too, instead of
    silently wrapping around to the opposite edge of the array.
    """
    rows, cols = data.shape[0], data.shape[1]
    if not (0 <= y_coord < rows and 0 <= x_coord < cols):
        raise IndexError("Point (x=" + str(x_coord) + ", y=" + str(y_coord) + ") lies outside the "
                         + str(cols) + "x" + str(rows) + " image")
    # for our procedure we had normalized magnitudes, so we set pixel values arbitrarily (but similar)
    data[y_coord, x_coord] = 2.512 ** mag


def get_info(image_info=None):
    """
    Purpose: Read a .csv file in which each line holds one lensed image's RA, Dec (both in degrees) and magnitude,
    comma-separated, and return the split values for each image
    :param image_info: Name of the input file without the '.csv' extension. If None (the default), the user is
    prompted for it
    :return: A list of lists, with each inner list holding the field strings (RA, Dec, magnitude) of a single lensed
    image (blank lines are skipped), as well as the name of the input file
    """
    # only characters that can appear in comma-separated decimal numbers are accepted
    allowed = set("0123456789.-\n,")
    if image_info is None:
        image_info = input('Please input the name of a file with the appropriate information and format: ')
    rows = []
    with open(image_info + '.csv', 'r') as raw:  # `with` guarantees the file is closed, even on early exit
        for line in raw:
            if not set(line) <= allowed:  # check to ensure characters are only coordinates
                print('Invalid character in .csv file found. Check for whitespaces or letters '
                      '(recommend opening in Notepad rather than a spreadsheet program). Program closing now.')
                quit()
            fields = line.rstrip().split(',')
            if fields == ['']:
                # a blank line (e.g. a trailing newline at the end of the file) holds no image; keeping it would
                # later crash float('') in write_all
                continue
            rows.append(fields)

    return rows, image_info


def write_all(length, convert=360000):
    """
    Purpose: Makes a .fits image and inputs points representing quasar images (all written with the same,
    normalized brightness)
    :param length: The length/width in pixels of the square .fits image to be created
    :param convert: Defines a unit conversion away from degrees to the unit of each pixel in the .fits image
    (default 360000, i.e. 0.01 arcseconds per pixel)
    :return: The file name of the new spoofed .fits file, as well as a list-of-lists containing the pixel location of
    each image in [y, x] order (the first entry is the image placed at the centre)
    """
    books = []  # pixel location of every written image, [y, x]; needed for 'bookkeeping' later
    info, quasar_name = get_info()
    if not info:
        print("No images were found in the .csv file. Program closing now.")
        quit()
    data = np.zeros((int(length), int(length)))  # create array of appropriate size
    head = fits.Header()  # create the header
    height, width = data.shape
    y_height = height // 2  # find middle of image
    x_width = width // 2
    # if any other header parts need to be written, this is the place to do it
    head['CRPIX1'] = y_height  # arbitrarily puts the first image at the centre
    head['CRPIX2'] = x_width
    image_prime = info[0]
    head['CRVAL1'] = float(image_prime[0])  # RA of the first image, in degrees
    head['CRVAL2'] = float(image_prime[1])  # Dec of the first image, in degrees
    # NOTE: these unit strings assume the default convert of 360000 (0.01 arcseconds per pixel)
    head['CUNIT1'] = '0.01 arcseconds'
    head['CUNIT2'] = '0.01 arcseconds'
    write_point(x_width, y_height, data)  # prime image is written and coordinate system defined
    books.append([y_height, x_width])
    for image in info[1:]:  # no need to take first image here; it is already written!
        # keep full precision here and let sky_to_pixels do the single conversion to a pixel index
        ra_image = float(image[0]) * convert
        dec_image = float(image[1]) * convert
        # mag = float(image[2])  # not using this right now, just normalizing image
        # forward `convert` so the header reference is scaled with the same factor as the image coordinates
        x_image, y_image = sky_to_pixels(dec_image, ra_image, head, convert)
        write_point(x_image, y_image, data)
        books.append([y_image, x_image])
    file_name = "spoofed_" + quasar_name + '.fits'
    fits.writeto(file_name, data, head)
    print("A new spoofed file should be available in the root folder.")
    return file_name, books


def semi_minor(a, x, y, x_c, y_c, angle):
    """
    Calculates the semi-minor axis b that an ellipse with the given semi-major axis, centre and tilt must have for
    the point (x, y) to lie on it, i.e. solves
    ((x-x_c)cos + (y-y_c)sin)^2 / a^2 + ((x-x_c)sin - (y-y_c)cos)^2 / b^2 = 1 for b (derived using Mathematica)
    :param a: The semi-major axis of the ellipse
    :param x: The x-coordinate of the point being calculated
    :param y: The y-coordinate of the point being calculated
    :param x_c: The x-coordinate of the center point
    :param y_c: The y-coordinate of the center point
    :param angle: The tilt angle of the ellipse, in radians
    :return: The semi-minor axis as calculated at this point, or None when a is zero, the equation is singular
    (zero denominator) or it has no positive solution
    """
    cos_t = m.cos(angle)
    sin_t = m.sin(angle)
    # expanded form kept term-for-term (same evaluation order) so results are bit-identical to the derivation
    numerator = -(a ** 2) * ((y_c * cos_t - y * cos_t - x_c * sin_t + x * sin_t) ** 2)
    denominator = (-(a ** 2)
                   + (x_c ** 2) * (cos_t ** 2) - 2 * x_c * x * (cos_t ** 2)
                   + (x ** 2) * (cos_t ** 2)
                   + 2 * x_c * y_c * cos_t * sin_t
                   - 2 * x * y_c * cos_t * sin_t
                   - 2 * x_c * y * cos_t * sin_t
                   + 2 * x * y * cos_t * sin_t
                   + (y_c ** 2) * (sin_t ** 2)
                   - 2 * y_c * y * (sin_t ** 2)
                   + (y ** 2) * (sin_t ** 2))

    if a == 0 or denominator == 0:  # mathematical singularity
        return None
    ratio = numerator / denominator
    return m.sqrt(ratio) if ratio > 0 else None


def ellipse_insanity(books, error=10):
    """
    Fits an ellipse to a lensed quasar image via a brute-force algorithm

    The user is prompted for the FITS X/Y of two opposite images. Of those two, the one with the larger summed
    distance to every image is used as the 'tip' of the ellipse. Every integer centre inside the bounding box of
    the images is then tried. A centre is accepted when the semi-minor axis implied by the tip agrees with the one
    implied by every image to within `error` pixels. If no centre is accepted, `error` is widened by 5 and the whole
    search is repeated while it stays below 100.
    :param books: A list-of-lists containing the [y, x] pixel coordinates of each quasar image
    :param error: The starting error range in pixels while comparing semi-minor axis calculations for each point
    :return: A list of the semi-major axis, semi-minor axis, centre-x, centre-y and theta, each averaged over all
    accepted centres. The program exits (via quit()) if no ellipse can be fitted.
    """
    # figure out how large the major axis should be, and decide which of the two extremum points is the tip
    x_t = int(input("Please input the FITS X of the first opposite image: "))
    y_t = int(input("Please input the FITS Y of the first opposite image: "))
    x_tw = int(input("Please input the FITS X of the second opposite image: "))
    y_tw = int(input("Please input the FITS Y of the second opposite image: "))
    opposites = m.sqrt((x_t - x_tw) ** 2 + (y_t - y_tw) ** 2)
    start_time = t.time()  # just for fun
    a_distance = 0
    b_distance = 0
    for item in books:
        a_distance += m.sqrt((x_t - item[1]) ** 2 + (y_t - item[0]) ** 2)
        b_distance += m.sqrt((x_tw - item[1]) ** 2 + (y_tw - item[0]) ** 2)
    tip = (y_t, x_t) if a_distance > b_distance else (y_tw, x_tw)  # stored as (y, x)

    # determine max and min values of theta
    omegas = []
    for item in books:
        shorta = item[1] - tip[1]  # dec difference
        shortb = item[0] - tip[0]  # ra difference
        if shorta != 0 and shortb != 0:
            longc = m.sqrt(shorta ** 2 + shortb ** 2)  # point distance
            # NOTE(review): longc ** 2 equals shorta ** 2 + shortb ** 2 by construction, so this acos argument is ~0
            # and every omega is ~pi/2 up to rounding; m.atan2(shortb, shorta) may have been intended
            omegas.append(m.acos((longc ** 2 - shorta ** 2 - shortb ** 2) / (-2 * shorta * shortb)))
    if not omegas:
        # min()/max() below would otherwise fail with an opaque ValueError
        print("No image is offset from the tip along both axes, so the range of theta cannot be bounded. "
              "System is unable to be fitted by this algorithm. Program closing")
        quit()
    min_omega = min(omegas)
    max_omega = max(omegas)

    # bounds for searching for the centre (it must be inside the image shape)
    y_vals = [item[0] for item in books]
    x_vals = [item[1] for item in books]
    x_max, x_min = max(x_vals), min(x_vals)
    y_max, y_min = max(y_vals), min(y_vals)

    positives = []
    # brute-force the centre point of the ellipse, increasing the allowed error in the minor-axis calculation by 5
    # each time zero results are found
    while not positives and error < 100:
        for y_c in range(y_min, y_max + 1):
            for x_c in range(x_min, x_max + 1):
                part1 = x_c - tip[1]  # dec difference
                part2 = y_c - tip[0]  # ra difference
                part3 = m.sqrt(part1 ** 2 + part2 ** 2)
                if part1 != 0 and part2 != 0:
                    # NOTE(review): same construction as the omegas above, so this is also ~pi/2
                    theta = m.acos((part3 ** 2 - part1 ** 2 - part2 ** 2) / (-2 * part1 * part2))
                elif part1 == 0 and tip[0] < y_c:
                    theta = m.pi / 2  # special case where theta is 90 degrees (tip is below center)
                elif part1 == 0 and tip[0] > y_c:
                    # NOTE(review): intended as 270 degrees (tip is above center) but the value is pi/2; left as-is
                    # because the omega bounds are ~pi/2, so 3*pi/2 would always be rejected
                    theta = m.pi / 2
                elif part2 == 0 and tip[1] > x_c:
                    theta = m.pi  # special case where theta is 180 degrees (tip is right of center)
                else:
                    theta = 0  # special case where theta is 0 degrees (tip is left of center)

                a = m.sqrt((tip[1] - x_c) ** 2 + (tip[0] - y_c) ** 2)
                # These checks involve only the tip and the centre, so reject early instead of once per image:
                # a must make physical sense (at least half the opposite-points distance) and theta must be in bounds.
                if not (a >= 0.5 * opposites and max_omega >= theta >= min_omega):
                    continue
                b_main = semi_minor(a, tip[1], tip[0], x_c, y_c, theta)
                if b_main is None:
                    continue
                # accept the centre only if every image implies (within error) the same semi-minor axis as the tip
                for item in books:
                    b_item = semi_minor(a, item[1], item[0], x_c, y_c, theta)
                    if b_item is None or not abs(b_main - b_item) <= error:
                        break
                else:
                    positives.append([a, b_main, x_c, y_c, theta])
        error += 5

    # just some statistics for testing
    print("\nFound " + str(len(positives)) + " ellipses in " + str((t.time() - start_time)) +
          " seconds, with semi-minor axis error of " + str(error - 5))

    # Test for an empty result rather than `error >= 100`: a fit found on the last allowed pass (error 95) also
    # leaves error at 100 and must not be reported as a failure.
    if not positives:
        print("Minor-axis error has gotten too large. System is unable to be fitted by this algorithm. Program closing")
        quit()

    # pick average result for each parameter (probably the reason results are mainly circular)
    count = len(positives)
    return [sum(params[i] for params in positives) / count for i in range(5)]


def lens_inversion(image, x_c, y_c, convert=360000):
    """
    Purpose: Finds the coordinates of the lens in degrees by reflecting the luminosity centroid location through the
    center point of the fitted ellipse
    :param image: File name of the (spoofed) .fits image whose header defines the sky coordinate reference
    :param x_c: The x-coordinate of the centre of the drawn ellipse
    :param y_c: The y-coordinate of the centre of the drawn ellipse
    :param convert: Defines a unit conversion away from degrees to the unit of each pixel in the .fits image; it is
    used both for the pixel-to-sky conversion and for converting the result back to degrees
    :return: The right ascension and declination of the lens in degrees, as well as the pixel location of the lens
    """
    x_centroid = int(input("Please input the FITS X of the centroid: "))
    y_centroid = int(input("Please input the FITS Y of the centroid: "))
    # reflect the centroid through the ellipse centre (i.e. flip it over both the major and minor axes)
    x_lens = 2 * x_c - x_centroid
    y_lens = 2 * y_c - y_centroid
    # `with` closes the file even if the conversion raises; convert to astronomy coordinates using the same
    # `convert` that is divided out below
    with fits.open(image) as hdul:
        head = hdul[0].header
        dec_lens, ra_lens = pixels_to_sky(x_lens, y_lens, head, convert)
    return ra_lens / convert, dec_lens / convert, x_lens, y_lens


def ellipse_draw(ellipse, spoofed, books, length=1200, error=0.01, margin=80):
    """
    Trace an ellipse onto a new, otherwise empty image array and save it as '<spoofed>_drawn.fits'
    :param ellipse: A list of semi-major axis, semi-minor axis, centre-x, centre-y, and tilt angle of an ellipse
    :param spoofed: The filename of the spoofed image; used as the prefix of the output file name
    :param books: The points onto which the ellipse fits, as [y, x] pixel pairs
    :param length: The size of the image
    :param error: Tolerance on the ellipse equation (|lhs - 1|) within which a pixel is drawn
    :param margin: Number of pixels around the bounding box of `books` that are searched (clipped to the image)
    :return: (none)
    """
    new_data = np.zeros((int(length), int(length)))
    a, b, x_c, y_c, theta = ellipse[:5]
    # loop-invariant trigonometry, evaluated once
    cos_t = m.cos(theta)
    sin_t = m.sin(theta)
    y_vals = [item[0] for item in books]
    x_vals = [item[1] for item in books]
    rows, cols = new_data.shape
    # Large bounds give a clearly visible ellipse, but they must be clipped to the image. A negative index would
    # silently wrap to the opposite edge, and one past the end would raise IndexError.
    y_lo = max(min(y_vals) - margin, 0)
    y_hi = min(max(y_vals) + margin, rows - 1)
    x_lo = max(min(x_vals) - margin, 0)
    x_hi = min(max(x_vals) + margin, cols - 1)
    for y in range(y_lo, y_hi + 1):
        dy = y - y_c
        for x in range(x_lo, x_hi + 1):
            dx = x - x_c
            # rotated ellipse equation: (u / a)^2 + (v / b)^2 == 1, with u along the major axis
            if abs((((dx * cos_t + dy * sin_t) ** 2) / a ** 2
                    + ((dx * sin_t - dy * cos_t) ** 2) / b ** 2) - 1) <= error:
                new_data[y, x] = 100
    fits.writeto(str(spoofed) + "_drawn.fits", new_data)


def master_function(draw=False, length=1200):
    """
    Purpose: Encompass all functions for lens inversion and allow user to perform the operation for multiple lensed
    quasars or files.
    :param draw: Toggles the drawing function
    :param length: Side length in pixels of the spoofed image (and of the drawn ellipse image). The default 1200
    represents 0.2' (600 pixels = 0.1', 1 pixel = 0.01")
    :return: (none)
    """
    terminate = 'n'
    while terminate != 'y':
        spoofed, images = write_all(length)
        print("Please open the new spoofed file in AstroImageJ.")
        ellipse = ellipse_insanity(images)
        x_c = ellipse[2]
        y_c = ellipse[3]
        ra_lens, dec_lens, x_lens, y_lens = lens_inversion(spoofed, x_c, y_c)
        print("\nLens information: ")
        print("Right Ascension in Degrees: " + str(ra_lens))
        print("Declination in Degrees: " + str(dec_lens))
        print("X-coordinate on image: " + str(int(x_lens)))
        print("Y-coordinate on image: " + str(int(y_lens)))
        if draw:
            # draw on an image of the same size as the spoofed one so pixel coordinates line up
            ellipse_draw(ellipse, spoofed, images, length)
        terminate = input("If you wish to exit this session, type 'y'. Other wise, type 'n' or another letter to solve "
                          "a new lensed system. ")


# Only run the interactive session when executed as a script, not when this module is imported.
if __name__ == "__main__":
    master_function(draw=True)
    print("Program ran to completion.")
