from __future__ import annotations

from dataclasses import dataclass
from typing import Final

from django.core.management.base import BaseCommand, CommandError
from django.db import transaction

from istat.models import IstatCountry


@dataclass(frozen=True)
class CountrySeed:
    """Immutable description of one country to seed.

    Instances are frozen, so fields cannot be reassigned after construction.
    """

    # Human-readable country name, e.g. "Italy".
    name: str
    # Primary country code; the seeds in this module use three-letter values
    # such as "ITA".
    code: str
    # Secondary code; the seeds in this module use two-letter values such
    # as "IT".
    iso_code: str


# Base set of countries inserted by the seed command, listed alphabetically
# by name. Each entry pairs a three-letter ``code`` with a two-letter
# ``iso_code``; the tuple is declared ``Final`` and is not meant to be rebound.
COUNTRY_SEEDS: Final[tuple[CountrySeed, ...]] = (
    CountrySeed(name="Australia", code="AUS", iso_code="AU"),
    CountrySeed(name="Brazil", code="BRA", iso_code="BR"),
    CountrySeed(name="Canada", code="CAN", iso_code="CA"),
    CountrySeed(name="China", code="CHN", iso_code="CN"),
    CountrySeed(name="France", code="FRA", iso_code="FR"),
    CountrySeed(name="Germany", code="DEU", iso_code="DE"),
    CountrySeed(name="India", code="IND", iso_code="IN"),
    CountrySeed(name="Italy", code="ITA", iso_code="IT"),
    CountrySeed(name="Japan", code="JPN", iso_code="JP"),
    CountrySeed(name="Netherlands", code="NLD", iso_code="NL"),
    CountrySeed(name="Portugal", code="PRT", iso_code="PT"),
    CountrySeed(name="Spain", code="ESP", iso_code="ES"),
    CountrySeed(name="Switzerland", code="CHE", iso_code="CH"),
    CountrySeed(name="United Arab Emirates", code="ARE", iso_code="AE"),
    CountrySeed(name="United Kingdom", code="GBR", iso_code="GB"),
    CountrySeed(name="United States", code="USA", iso_code="US"),
)


def _normalize_code(value: str | None) -> str:
    """Return ``value`` trimmed of surrounding whitespace and upper-cased.

    ``None`` and the empty string both normalize to ``""`` so callers can
    compare codes without special-casing missing values.
    """
    if not value:
        return ""
    trimmed = value.strip()
    return trimmed.upper()


def _build_country_instances(
    seeds: tuple[CountrySeed, ...],
    existing_codes: set[str],
    existing_iso_codes: set[str],
) -> tuple[list[IstatCountry], int]:
    """Turn seeds into unsaved ``IstatCountry`` objects, dropping duplicates.

    A seed is a duplicate when its normalized ``code`` or its normalized
    ``iso_code`` is already taken, either by a value passed in through
    ``existing_codes`` / ``existing_iso_codes`` or by an earlier seed in the
    same call.

    Args:
        seeds: Seed records to consider, processed in order.
        existing_codes: Normalized codes that are already taken.
        existing_iso_codes: Normalized ISO codes that are already taken.

    Returns:
        A pair ``(instances, skipped)``: the unsaved model instances to
        insert and the number of seeds dropped as duplicates.
    """
    # Work on copies so the caller's sets are never mutated.
    taken_codes = set(existing_codes)
    taken_iso_codes = set(existing_iso_codes)

    pending: list[IstatCountry] = []
    duplicates = 0

    for entry in seeds:
        code = _normalize_code(entry.code)
        iso_code = _normalize_code(entry.iso_code)

        is_new = code not in taken_codes and iso_code not in taken_iso_codes
        if is_new:
            country = IstatCountry(
                name=entry.name.strip(),
                code=code,
                iso_code=iso_code,
            )
            pending.append(country)
            # Reserve both codes so later seeds cannot reuse them.
            taken_codes.add(code)
            taken_iso_codes.add(iso_code)
        else:
            duplicates += 1

    return pending, duplicates


class Command(BaseCommand):
    """Management command that seeds a base set of ISTAT countries.

    Seeds whose normalized ``code`` or ``iso_code`` already exists in the
    database are skipped, so re-running the command does not re-insert them.
    """

    help = "Seed a base set of countries into ISTAT country references."

    def handle(self, *args, **options) -> None:
        """Insert the missing seed countries and print a short summary.

        Raises:
            CommandError: If any step fails; the original exception is
                chained as the cause.
        """
        try:
            existing_codes = {
                _normalize_code(code)
                for code in IstatCountry.objects.values_list("code", flat=True)
            }
            # Rows without an ISO code are left out so that a missing value
            # is never treated as an already-taken code.
            existing_iso_codes = {
                _normalize_code(iso_code)
                for iso_code in IstatCountry.objects.exclude(iso_code__isnull=True)
                .exclude(iso_code__exact="")
                .values_list("iso_code", flat=True)
            }

            countries_to_create, skipped_count = _build_country_instances(
                seeds=COUNTRY_SEEDS,
                existing_codes=existing_codes,
                existing_iso_codes=existing_iso_codes,
            )

            attempted_count = len(countries_to_create)
            new_codes = [country.code for country in countries_to_create]

            with transaction.atomic():
                IstatCountry.objects.bulk_create(
                    countries_to_create,
                    ignore_conflicts=True,
                )
                # ignore_conflicts=True makes the database silently drop rows
                # that violate a constraint, so the length of the input list
                # can overstate what was written. None of these codes was
                # present when the existing codes were read, so counting them
                # now gives the number of rows that actually landed.
                persisted_count = IstatCountry.objects.filter(
                    code__in=new_codes
                ).count()

            # Clamp defensively so the summary never reports more rows than
            # this run attempted to insert.
            created_count = min(persisted_count, attempted_count)
            ignored_count = attempted_count - created_count

            self.stdout.write(f"Created: {created_count} countries")
            self.stdout.write(f"Skipped: {skipped_count} already exist")
            if ignored_count:
                self.stdout.write(
                    f"Ignored: {ignored_count} conflicted with existing rows"
                )
        except Exception as exc:
            raise CommandError(f"Failed to seed countries: {exc}") from exc
