"""Module for the .source file reader."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['read_sources']

# 1. Standard Python modules
from datetime import datetime
import itertools
import logging
from pathlib import Path
from typing import TextIO, TypeAlias

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.ptmio.file_reader import FileReader
from xms.ptmio.source.sources import (
    InstantMassInstruction, InstantMassSource, LineMassDatum, LineMassInstruction, LineMassSource, PointMassInstruction,
    PointMassSource, PolygonMassInstruction, PolygonMassSource, Sources
)

# Union of the four concrete source types a Sources container holds; renumber_sources() uses it to
# annotate the combined list it builds from all of them.
AnySource: TypeAlias = PointMassSource | InstantMassSource | LineMassSource | PolygonMassSource


def read_sources(where: str | Path | TextIO) -> Sources:
    """
    Read a .sources file.

    Args:
        where: A file path (str or Path) or a text stream to read from.

    Returns:
        Sources contained in the file. The underlying FileReader is used as a context manager, so it has been
        exited by the time this returns.
    """
    with FileReader(where) as reader:
        return read_all_sources(reader)


def read_all_sources(reader: FileReader) -> Sources:
    """
    Read every section of a sources file.

    Sections are read in the fixed order instant, point, line, polygon. Duplicate IDs are tracked across all
    sections while reading and resolved by renumbering once everything has been read.

    Args:
        reader: Reader to consume the sections from.

    Returns:
        The parsed sources, with every source ID unique.
    """
    sources = Sources()
    # Source ID -> line it was first defined on, shared by all sections so cross-section duplicates are caught.
    definitions: dict[int, int] = {}

    section_readers = (read_instant_sources, read_point_sources, read_line_sources, read_polygon_sources)
    for read_section in section_readers:
        read_section(reader, definitions, sources)

    renumber_sources(sources)
    return sources


def read_instant_sources(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read the instant-source section.

    The section begins with a line holding the number of instant sources, followed by that many definitions.

    Args:
        reader: Reader to consume the section from.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed sources are appended to.
    """
    source_count = reader.read_int()
    reader.next_line()
    for _source_index in range(source_count):
        read_instant_source(reader, definitions, sources)


def read_instant_source(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read one instant source and append it to ``sources.instant_sources``.

    The header line holds the source ID, the number of instructions, and a label (the rest of the line, with
    surrounding quotes and whitespace stripped). Each instruction is one line holding, in order: year, month, day,
    hour, minute, second, x, y, z, then parcel_mass, h_radius, v_radius, mass_rate, grain_size, stdev, density,
    velocity, initiation, and deposition.

    Args:
        reader: Reader positioned on the source's header line.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed source is appended to.
    """
    source_id = reader.read_int()
    # Checked while still on the header line so a duplicate warning cites this line.
    check_duplicate_ids(reader.line_number, source_id, definitions)
    instruction_count = reader.read_int()
    raw_label = reader.read_remainder()
    reader.next_line()

    label = raw_label.strip('\r\n\t "')
    instructions = []

    for _ in range(instruction_count):
        year, month, day = reader.read_int(count=3)
        hour, minute, second = reader.read_int(count=3)
        x, y, z = reader.read_float(count=3)
        parcel_mass, h_radius, v_radius, mass_rate = reader.read_float(count=4)
        grain_size, stdev, density = reader.read_float(count=3)
        velocity, initiation, deposition = reader.read_float(count=3)
        reader.next_line()

        instructions.append(
            InstantMassInstruction(
                time=datetime(year=year, month=month, day=day, hour=hour, minute=minute, second=second),
                location=(x, y, z),
                parcel_mass=parcel_mass,
                h_radius=h_radius,
                v_radius=v_radius,
                mass_rate=mass_rate,
                grain_size=grain_size,
                stdev=stdev,
                density=density,
                velocity=velocity,
                initiation=initiation,
                deposition=deposition,
            )
        )

    sources.instant_sources.append(InstantMassSource(source_id=source_id, label=label, instructions=instructions))


def read_point_sources(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read the point-source section.

    The section begins with a line holding the number of point sources, followed by that many definitions.

    Args:
        reader: Reader to consume the section from.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed sources are appended to.
    """
    source_count = reader.read_int()
    reader.next_line()
    for _source_index in range(source_count):
        read_point_source(reader, definitions, sources)


def read_point_source(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read one point source and append it to ``sources.point_sources``.

    The header line holds the source ID, the number of instructions, and a label (the rest of the line, with
    surrounding quotes and whitespace stripped). Each instruction is one line holding, in order: year, month, day,
    hour, minute, second, x, y, z, then parcel_mass, h_radius, v_radius, mass_rate, grain_size, stdev, density,
    velocity, initiation, and deposition.

    Args:
        reader: Reader positioned on the source's header line.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed source is appended to.
    """
    source_id, instruction_count = reader.read_int(count=2)
    # Checked while still on the header line so a duplicate warning cites this line.
    check_duplicate_ids(reader.line_number, source_id, definitions)
    raw_label = reader.read_remainder()
    reader.next_line()

    label = raw_label.strip('\r\n\t "')
    instructions = []

    for _ in range(instruction_count):
        year, month, day = reader.read_int(count=3)
        hour, minute, second = reader.read_int(count=3)
        x, y, z = reader.read_float(count=3)
        parcel_mass, h_radius, v_radius, mass_rate = reader.read_float(count=4)
        grain_size, stdev, density = reader.read_float(count=3)
        velocity, initiation, deposition = reader.read_float(count=3)
        reader.next_line()

        instructions.append(
            PointMassInstruction(
                time=datetime(year=year, month=month, day=day, hour=hour, minute=minute, second=second),
                location=(x, y, z),
                parcel_mass=parcel_mass,
                h_radius=h_radius,
                v_radius=v_radius,
                mass_rate=mass_rate,
                grain_size=grain_size,
                stdev=stdev,
                density=density,
                velocity=velocity,
                initiation=initiation,
                deposition=deposition,
            )
        )

    sources.point_sources.append(PointMassSource(source_id=source_id, label=label, instructions=instructions))


def read_line_sources(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read the line-source section.

    The section begins with a line holding the number of line sources, followed by that many definitions.

    Args:
        reader: Reader to consume the section from.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed sources are appended to.
    """
    source_count = reader.read_int()
    reader.next_line()
    for _source_index in range(source_count):
        read_line_source(reader, definitions, sources)


def read_line_source(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read one line source and append it to ``sources.line_sources``.

    The header line holds the source ID, the number of instructions, a label token, and an optional datum keyword
    (``beddatum``, ``surfacedatum``, or ``depthdistributed``; an omitted/empty datum maps to ``LineMassDatum.none``).
    Each instruction is one line holding, in order: year, month, day, hour, minute, second, the start point
    (x y z), the end point (x y z), then parcel_mass, h_radius, v_radius, mass_rate, grain_size, stdev, density,
    velocity, initiation, and deposition.

    Args:
        reader: Reader positioned on the source's header line.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed source is appended to.

    Raises:
        The exception returned by ``reader.error`` if the datum keyword is not recognized.
    """
    source_id, instruction_count = reader.read_int(count=2)
    # Checked while still on the header line so a duplicate warning cites this line.
    check_duplicate_ids(reader.line_number, source_id, definitions)
    raw_label = reader.read_str()
    datum_keyword = reader.read_str(optional=True)

    datum_by_keyword = {
        'beddatum': LineMassDatum.bed_datum,
        'surfacedatum': LineMassDatum.surface_datum,
        'depthdistributed': LineMassDatum.depth_distributed,
        '': LineMassDatum.none,
    }
    if datum_keyword not in datum_by_keyword:
        raise reader.error('Datum must be omitted or one of beddatum, surfacedatum, or depthdistributed.')
    datum = datum_by_keyword[datum_keyword]

    # Advance only after the datum is validated, so a bad datum is reported against the header line/field.
    reader.next_line()

    label = raw_label.strip('\r\n\t "')
    instructions = []

    for _ in range(instruction_count):
        year, month, day = reader.read_int(count=3)
        hour, minute, second = reader.read_int(count=3)
        start_x, start_y, start_z = reader.read_float(count=3)
        end_x, end_y, end_z = reader.read_float(count=3)
        parcel_mass, h_radius, v_radius, mass_rate = reader.read_float(count=4)
        grain_size, stdev, density = reader.read_float(count=3)
        velocity, initiation, deposition = reader.read_float(count=3)
        reader.next_line()

        instructions.append(
            LineMassInstruction(
                time=datetime(year=year, month=month, day=day, hour=hour, minute=minute, second=second),
                start=(start_x, start_y, start_z),
                end=(end_x, end_y, end_z),
                parcel_mass=parcel_mass,
                h_radius=h_radius,
                v_radius=v_radius,
                mass_rate=mass_rate,
                grain_size=grain_size,
                stdev=stdev,
                density=density,
                velocity=velocity,
                initiation=initiation,
                deposition=deposition,
            )
        )

    sources.line_sources.append(
        LineMassSource(source_id=source_id, label=label, datum=datum, instructions=instructions)
    )


def read_polygon_sources(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read the polygon-source section.

    The section begins with a line holding the number of polygon sources, followed by that many definitions.

    Args:
        reader: Reader to consume the section from.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed sources are appended to.
    """
    source_count = reader.read_int()
    reader.next_line()
    for _source_index in range(source_count):
        read_polygon_source(reader, definitions, sources)


def read_polygon_source(reader: FileReader, definitions: dict[int, int], sources: Sources):
    """
    Read one polygon source and append it to ``sources.polygon_sources``.

    The header line holds the source ID, the number of polygon points, the number of instructions, and a label
    (the rest of the line, with surrounding quotes and whitespace stripped). Each instruction spans several lines:
    one line with year, month, day, hour, minute, second; one ``x y z`` line per polygon point; then one line with
    parcel_mass, h_radius, v_radius, mass_rate, grain_size, stdev, density, velocity, initiation, and deposition.

    Args:
        reader: Reader positioned on the source's header line.
        definitions: Map of source ID to the line it was first defined on; updated in place.
        sources: Container the parsed source is appended to.
    """
    source_id, point_count, instruction_count = reader.read_int(count=3)
    # Checked while still on the header line so a duplicate warning cites this line.
    check_duplicate_ids(reader.line_number, source_id, definitions)
    raw_label = reader.read_remainder()
    reader.next_line()

    label = raw_label.strip('\r\n\t "')

    def read_vertex():
        """Read one ``x y z`` point line and advance past it."""
        x, y, z = reader.read_float(count=3)
        reader.next_line()
        return x, y, z

    instructions = []
    for _ in range(instruction_count):
        year, month, day = reader.read_int(count=3)
        hour, minute, second = reader.read_int(count=3)
        time = datetime(year=year, month=month, day=day, hour=hour, minute=minute, second=second)
        reader.next_line()

        points = [read_vertex() for _ in range(point_count)]

        parcel_mass, h_radius, v_radius, mass_rate = reader.read_float(count=4)
        grain_size, stdev, density = reader.read_float(count=3)
        velocity, initiation, deposition = reader.read_float(count=3)
        reader.next_line()

        instructions.append(
            PolygonMassInstruction(
                time=time,
                points=points,
                parcel_mass=parcel_mass,
                h_radius=h_radius,
                v_radius=v_radius,
                mass_rate=mass_rate,
                grain_size=grain_size,
                stdev=stdev,
                density=density,
                velocity=velocity,
                initiation=initiation,
                deposition=deposition,
            )
        )

    sources.polygon_sources.append(PolygonMassSource(source_id=source_id, label=label, instructions=instructions))


def check_duplicate_ids(current_line: int, new_id: int, definitions: dict[int, int]):
    """
    Record where a source ID was first defined, warning if the ID was already seen.

    Args:
        current_line: Line number the source with ``new_id`` is defined on.
        new_id: ID of the source being read.
        definitions: Map of source ID to the line it was first defined on. A new ID is added; an ID that is
            already present keeps pointing at its first definition, and a warning is logged on 'xms.ptm.io'.
    """
    if new_id in definitions:
        previous_line = definitions[new_id]
        logging.getLogger('xms.ptm.io').warning(
            f'The source defined on line {current_line} has the same ID as the source defined on line {previous_line}. '
            "The source's ID will be renumbered to be unique."
        )
    else:
        definitions[new_id] = current_line


def renumber_sources(sources: Sources):
    """
    Ensure every source has a unique ID.

    Sources are visited in the order instant, point, line, polygon. The first source seen with a given ID keeps
    it. Every later source with the same ID is given the smallest positive ID that no source used before
    renumbering and that has not already been handed out.

    Args:
        sources: The sources to renumber. Source IDs are modified in place.
    """
    all_sources: list[AnySource] = [
        *sources.instant_sources,
        *sources.point_sources,
        *sources.line_sources,
        *sources.polygon_sources,
    ]

    # Snapshot of the IDs before renumbering. A set keeps each membership test O(1); scanning a list here made
    # renumbering quadratic in the number of sources. Because replacement IDs are drawn only from values outside
    # this snapshot, they can never collide with an original ID or with each other.
    original_ids = {source.source_id for source in all_sources}
    unused_ids = (i for i in itertools.count(start=1) if i not in original_ids)

    seen_ids = set()
    for source in all_sources:
        if source.source_id in seen_ids:
            source.source_id = next(unused_ids)
        else:
            seen_ids.add(source.source_id)
