from __future__ import annotations

from django.core.management.base import BaseCommand, CommandError
from django.db import transaction

from guests.models import Guest
from services.country_utils import (
    COUNTRY_FIELD_NAMES,
    is_iso2_country,
    normalize_country,
)


# Keys inspected inside ``Guest.extra_data``. Currently the same names as the
# model-level country fields; presumably kept as a separate constant so the two
# lists can diverge later — confirm before merging them.
EXTRA_DATA_COUNTRY_FIELDS = (*COUNTRY_FIELD_NAMES,)


class Command(BaseCommand):
    """Normalize guest country values to ISO-2 codes.

    Inspects every name in ``COUNTRY_FIELD_NAMES`` as a model attribute on each
    ``Guest`` and every name in ``EXTRA_DATA_COUNTRY_FIELDS`` as a key inside
    ``Guest.extra_data``. Runs as a dry run unless ``--apply`` is given; in both
    modes it prints a summary plus (a capped list of) values it could not map.
    """

    help = "Normalize guest country fields to ISO-2 codes."

    # How many unresolved entries are echoed individually before summarizing.
    _MAX_REPORTED_UNRESOLVED = 50

    def add_arguments(self, parser):
        parser.add_argument(
            "--apply",
            action="store_true",
            help="Persist normalized values. Without this flag the command runs in dry-run mode.",
        )
        parser.add_argument(
            "--limit",
            type=int,
            default=None,
            help="Optional number of guest rows to inspect.",
        )

    def handle(self, *args, **options):
        """Scan guests (ordered by id), normalize country values, optionally save.

        Raises:
            CommandError: if ``--limit`` is given but is not a positive integer.
        """
        apply_changes = options["apply"]
        limit = options.get("limit")
        # ``--limit 0`` used to fall through to a full-table scan (and, with
        # --apply, a full-table rewrite); negatives failed inside the ORM slice.
        if limit is not None and limit < 1:
            raise CommandError("--limit must be a positive integer.")

        queryset = Guest.objects.all().order_by("id")
        if limit is not None:
            queryset = queryset[:limit]

        updated_guests = 0
        updated_fields = 0
        unresolved_entries: list[str] = []

        for guest in queryset:
            model_updates: dict[str, str] = {}
            for field_name in COUNTRY_FIELD_NAMES:
                raw_value = getattr(guest, field_name, None)
                normalized, unresolved = self._resolve(raw_value)
                if normalized is not None:
                    model_updates[field_name] = normalized
                elif unresolved:
                    unresolved_entries.append(
                        f"guest={guest.id} field={field_name} value={raw_value}"
                    )

            # A non-dict extra_data (e.g. None) is treated as having no keys.
            extra_data = guest.extra_data if isinstance(guest.extra_data, dict) else {}
            extra_changes: dict[str, str] = {}
            for field_name in EXTRA_DATA_COUNTRY_FIELDS:
                raw_value = extra_data.get(field_name)
                normalized, unresolved = self._resolve(raw_value)
                if normalized is not None:
                    extra_changes[field_name] = normalized
                elif unresolved:
                    unresolved_entries.append(
                        f"guest={guest.id} extra_data.{field_name} value={raw_value}"
                    )

            if not model_updates and not extra_changes:
                continue

            # Counted in dry-run mode too, so the summary shows what *would* change.
            updated_guests += 1
            updated_fields += len(model_updates) + len(extra_changes)
            if apply_changes:
                self._save_guest(guest, model_updates, extra_data, extra_changes)

        self._report(apply_changes, updated_guests, updated_fields, unresolved_entries)

    @staticmethod
    def _resolve(raw_value) -> tuple[str | None, bool]:
        """Classify a single stored country value.

        Returns a ``(normalized, unresolved)`` pair:

        * ``(None, False)``: nothing to do; the value is empty or already ISO-2.
        * ``(code, False)``: ``code`` is the ISO-2 replacement for ``raw_value``.
        * ``(None, True)``: ``normalize_country`` could not produce an ISO-2 code.
        """
        if not raw_value or is_iso2_country(raw_value):
            return None, False
        # Only values that actually need mapping reach the normalizer.
        normalized = normalize_country(raw_value)
        if normalized and is_iso2_country(normalized):
            if normalized != raw_value:
                return normalized, False
            return None, False
        return None, True

    @staticmethod
    def _save_guest(guest, model_updates, extra_data, extra_changes) -> None:
        """Apply normalized values to ``guest`` and save only the touched columns."""
        with transaction.atomic():
            for field_name, value in model_updates.items():
                setattr(guest, field_name, value)
            update_fields = list(model_updates)
            if extra_changes:
                # Build a new dict (existing key order preserved) instead of
                # mutating the loaded extra_data in place.
                guest.extra_data = {**extra_data, **extra_changes}
                update_fields.append("extra_data")
            guest.save(update_fields=update_fields)

    def _report(
        self,
        apply_changes: bool,
        updated_guests: int,
        updated_fields: int,
        unresolved_entries: list[str],
    ) -> None:
        """Write the run summary and a capped list of unresolved values to stdout."""
        mode_label = "APPLY" if apply_changes else "DRY-RUN"
        self.stdout.write(f"Mode: {mode_label}")
        self.stdout.write(f"Guests changed: {updated_guests}")
        self.stdout.write(f"Fields normalized: {updated_fields}")
        self.stdout.write(f"Unresolved values: {len(unresolved_entries)}")

        cap = self._MAX_REPORTED_UNRESOLVED
        for entry in unresolved_entries[:cap]:
            self.stdout.write(f" - {entry}")

        remaining = len(unresolved_entries) - cap
        if remaining > 0:
            self.stdout.write(f" - ... {remaining} more unresolved entries")
