from __future__ import annotations

import csv
import io
from pathlib import Path
from zipfile import ZipFile

from django.core.management.base import BaseCommand, CommandError
from django.db import transaction

from istat.models import IstatMunicipality
from istat.municipalities import normalize_municipality_name, normalize_province_code


def _clean_key(value):
    return str(value or "").strip().lower()


def _clean_value(value):
    return str(value or "").strip() if value is not None else ""


def _normalized_row(row):
    """Build a new dict from *row* with cleaned keys and cleaned values.

    Keys go through ``_clean_key`` and values through ``_clean_value``; when
    two original keys clean to the same text, the later one wins.
    """
    cleaned = {}
    for raw_key, raw_value in row.items():
        key = _clean_key(raw_key)
        cleaned[key] = _clean_value(raw_value)
    return cleaned


def _decode(raw):
    for encoding in ("utf-8-sig", "cp1252", "latin1"):
        try:
            return raw.decode(encoding)
        except UnicodeDecodeError:
            continue
    return raw.decode("latin1", errors="replace")


def _read_csv_rows(raw):
    """Yield one normalized dict per data row of a delimited text file.

    *raw* is the undecoded file content. The delimiter is sniffed from the
    first 4096 characters among ``;``, ``,``, ``|`` and tab; when sniffing
    fails the standard comma-separated ``csv.excel`` dialect is used.
    """
    text = _decode(raw)
    try:
        dialect = csv.Sniffer().sniff(text[:4096], delimiters=";,|\t")
    except csv.Error:
        dialect = csv.excel
    # newline="" leaves line-ending handling to the csv module, as its
    # documentation requires. With StringIO's default ("\n"), a file that uses
    # bare "\r" line endings reaches the reader as one giant line and fails
    # with "new-line character seen in unquoted field".
    reader = csv.DictReader(io.StringIO(text, newline=""), dialect=dialect)
    for row in reader:
        if row:
            yield _normalized_row(row)


def _read_xlsx_rows(raw):
    """Yield one dict per non-empty data row of the active sheet of an XLSX file.

    *raw* is the workbook's bytes. The first row supplies the column headers
    (cleaned with ``_clean_key``); every later row is zipped against them with
    values cleaned by ``_clean_value``. Rows whose cleaned values are all empty
    are skipped. An empty sheet yields nothing.

    Raises:
        CommandError: if ``openpyxl`` is not installed.
    """
    # Imported lazily so CSV-only imports work without openpyxl installed.
    try:
        from openpyxl import load_workbook
    except ImportError as exc:
        raise CommandError("openpyxl is required to import XLSX municipality files.") from exc

    workbook = load_workbook(io.BytesIO(raw), read_only=True, data_only=True)
    try:
        rows = workbook.active.iter_rows(values_only=True)
        headers = next(rows, None)
        if headers is None:
            return
        keys = [_clean_key(header) for header in headers]
        for values in rows:
            row = {key: _clean_value(value) for key, value in zip(keys, values)}
            if any(row.values()):
                yield row
    finally:
        # Read-only workbooks load lazily and must be closed explicitly; the
        # finally clause also runs if the consumer abandons the generator.
        workbook.close()


def _read_rows(path):
    """Yield normalized rows from the CSV, XLSX or ZIP file at *path*.

    A ``.zip`` file is searched for its first member whose name ends in
    ``.csv`` or ``.xlsx`` and that member is parsed. Any other file is parsed
    as XLSX when its suffix is ``.xlsx`` and as CSV otherwise.

    Raises:
        CommandError: if a ZIP archive holds no CSV or XLSX member.
    """
    extension = path.suffix.lower()

    if extension != ".zip":
        content = path.read_bytes()
        parse = _read_xlsx_rows if extension == ".xlsx" else _read_csv_rows
        yield from parse(content)
        return

    with ZipFile(path) as archive:
        candidates = [
            entry
            for entry in archive.namelist()
            if entry.lower().endswith((".csv", ".xlsx"))
        ]
        if not candidates:
            raise CommandError(f"No CSV or XLSX file found inside {path}")
        member = candidates[0]
        content = archive.read(member)
        parse = _read_xlsx_rows if member.lower().endswith(".xlsx") else _read_csv_rows
        yield from parse(content)


def _first(row, *keys):
    for key in keys:
        value = row.get(key)
        if value:
            return value
    return ""


def _municipality_payload_from_row(row):
    """Translate one normalized input row into ``IstatMunicipality`` field values.

    Each field is taken from the first non-empty column among several accepted
    header spellings. Returns ``None`` when the row has no code or no name.

    Code normalization: an all-digit code is left-padded with zeros to six
    characters; any code that is then exactly six characters long gets ``000``
    appended; an all-digit code of seven or eight characters is left-padded
    with zeros to nine.
    """
    code_columns = (
        "codice comune formato alfanumerico",
        "codice comune formato numerico",
        "codice comune",
        "codice_comune",
        "codice istat",
        "codice_istat",
        "code",
        "codice",
    )
    name_columns = (
        "denominazione in italiano",
        "denominazione comune",
        "denominazione",
        "comune",
        "name",
    )
    province_columns = (
        "sigla automobilistica",
        "sigla provincia",
        "sigla_provincia",
        "provincia",
        "province",
    )
    region_columns = ("denominazione regione", "regione", "region")

    code = _first(row, *code_columns)
    name = _first(row, *name_columns)
    province = _first(row, *province_columns)
    region = _first(row, *region_columns)

    if not (code and name):
        return None

    code_text = str(code).strip()
    # Zero-padding never changes whether the text is all digits, so the check
    # is done once up front.
    all_digits = code_text.isdigit()
    if all_digits:
        code_text = code_text.zfill(6)

    if len(code_text) == 6:
        code_text += "000"
    elif all_digits and len(code_text) < 9:
        code_text = code_text.zfill(9)

    payload = {"code": code_text, "name": name}
    payload["normalized_name"] = normalize_municipality_name(name)
    payload["province"] = normalize_province_code(province) or None
    payload["region"] = region or None
    payload["is_active"] = True
    return payload


class Command(BaseCommand):
    """Management command that upserts municipalities from an ISTAT data file.

    Rows are matched to existing ``IstatMunicipality`` records by ``code``:
    unknown codes are created, known codes are overwritten with the file's
    values. With ``--deactivate-missing``, records whose code is absent from
    the file are marked inactive.
    """

    help = "Import official ISTAT Italian municipality data into the database."

    def add_arguments(self, parser):
        """Register the ``--file`` and ``--deactivate-missing`` options."""
        parser.add_argument("--file", required=True, type=str, help="Official ISTAT CSV/XLSX/ZIP file")
        parser.add_argument(
            "--deactivate-missing",
            action="store_true",
            help="Mark existing municipalities not present in the file as inactive.",
        )

    def handle(self, *args, **options):
        """Parse the file, then create/update/deactivate rows in one transaction.

        Raises:
            CommandError: if the path is not an existing regular file, or if no
                row in the file yields a usable municipality payload.
        """
        path = Path(options["file"])
        # is_file() also rejects directories, which exists() would accept and
        # which would then fail later with a raw OS error instead of a
        # CommandError.
        if not path.is_file():
            raise CommandError(f"Municipality file not found: {path}")

        payloads = []
        seen_codes = set()
        for row in _read_rows(path):
            payload = _municipality_payload_from_row(row)
            # Skip unusable rows and keep only the first row for each code.
            if not payload or payload["code"] in seen_codes:
                continue
            seen_codes.add(payload["code"])
            payloads.append(payload)

        if not payloads:
            raise CommandError("No municipality rows could be parsed from the file.")

        # All writes share one transaction so a failure leaves the table untouched.
        with transaction.atomic():
            existing_by_code = {
                municipality.code: municipality
                for municipality in IstatMunicipality.objects.filter(code__in=seen_codes)
            }
            creates = []
            updates = []
            for payload in payloads:
                municipality = existing_by_code.get(payload["code"])
                if municipality is None:
                    creates.append(IstatMunicipality(**payload))
                    continue
                for field, value in payload.items():
                    setattr(municipality, field, value)
                updates.append(municipality)

            if creates:
                IstatMunicipality.objects.bulk_create(creates, batch_size=1000)
            if updates:
                # "code" is the lookup key and therefore never changes here.
                IstatMunicipality.objects.bulk_update(
                    updates,
                    ["name", "normalized_name", "province", "region", "is_active"],
                    batch_size=1000,
                )

            deactivated = 0
            if options["deactivate_missing"]:
                deactivated = IstatMunicipality.objects.exclude(
                    code__in=seen_codes
                ).update(is_active=False)

        self.stdout.write(
            self.style.SUCCESS(
                "Imported ISTAT municipalities: "
                f"{len(creates)} created, {len(updates)} updated, {deactivated} deactivated."
            )
        )
