from __future__ import annotations

import io
import json
import logging
import os
from datetime import datetime, timezone

from geoalchemy2.shape import from_shape
from shapely.geometry import Point
from sqlalchemy import text
from sqlalchemy.orm import Session

from .celery_app import celery_app
from .db import SessionLocal
from .storage import read_bytes
from . import inference, rules

# No ORM models are imported here: the tasks below issue raw SQL via text(), which keeps the worker lightweight.

def _parse_point_wkt(wkt):
    """Return ``(lat, lng)`` from a ``POINT(lng lat)`` WKT string, or ``None`` if it cannot be parsed."""
    if not wkt or not wkt.startswith("POINT"):
        return None
    start = wkt.find("(")
    end = wkt.find(")", start + 1)
    if start == -1 or end == -1:
        # e.g. "POINT EMPTY" carries no coordinate list.
        return None
    try:
        coords = [float(part) for part in wkt[start + 1:end].split()]
    except ValueError:
        return None
    if len(coords) < 2:
        return None
    # Coordinates are read as "lng lat"; any further values (Z/M) are ignored.
    lng, lat = coords[:2]
    return lat, lng


@celery_app.task(name="worker.process_media")
def process_media(media_id: str):
    """Run stub inference on one media row and store a detection and a hazard per result.

    Loads the ``media`` row identified by ``media_id``, reads its bytes via
    ``read_bytes(s3_key)``, turns them into an image (a frame extracted at 1 s
    for ``VIDEO`` / ``DRONE_VIDEO``, otherwise a PIL image converted to RGB),
    and passes that image to ``inference.run_stub``. For every returned
    detection one row is inserted into ``detections`` and one into ``hazards``,
    both pinned at the media's GPS point, or at (0, 0) when the media has no
    usable GPS value. All inserts are committed together.

    Args:
        media_id: Primary key of the ``media`` row to process.

    Returns:
        ``{"ok": True, "detections": <count>}`` on success, or
        ``{"ok": False, "error": <message>}`` when the row does not exist or
        any step raises (the transaction is rolled back in that case).
    """
    log = logging.getLogger(__name__)
    db = SessionLocal()
    try:
        # Load media row
        row = db.execute(text("""
            SELECT m.id, m.s3_key, m.media_type, ST_AsText(m.gps) as gps_wkt, m.site_id, m.uploader_id
            FROM media m
            WHERE m.id = :mid
        """), {"mid": media_id}).mappings().first()
        if not row:
            return {"ok": False, "error": "Media not found"}

        media_bytes = read_bytes(row["s3_key"])

        # Convert to image for stub inference
        if row["media_type"] in ("VIDEO", "DRONE_VIDEO"):
            img = inference.extract_frame_from_video(media_bytes, at_seconds=1)
            frame_time_ms = 1000  # matches at_seconds=1 above
        else:
            from PIL import Image
            img = Image.open(io.BytesIO(media_bytes)).convert("RGB")  # type: ignore
            frame_time_ms = None

        dets = inference.run_stub(img)

        # Determine pin from gps or fallback (0,0)
        point = _parse_point_wkt(row["gps_wkt"])
        if point is None:
            if row["gps_wkt"]:
                # A GPS value exists but is not a parseable POINT; use the same
                # fallback as for missing GPS instead of failing the whole task.
                log.warning(
                    "Unparseable gps value %r for media %s; using (0, 0)",
                    row["gps_wkt"], media_id,
                )
            lat, lng = 0.0, 0.0
        else:
            lat, lng = point

        # Insert detections + hazards
        for d in dets:
            signal = d["signal"]
            conf = float(d.get("confidence", 0.5))
            hz_type, title, sev = rules.hazard_from_signal(signal)
            # Serialise once with a real JSON encoder; str()/replace() yields
            # invalid JSON for None/True/False and for strings containing quotes.
            tags_json = json.dumps(rules.rule_tags_for(signal))

            # detection
            # CAST(... AS jsonb) is used instead of ":tags::jsonb" because a bind
            # parameter directly followed by "::" is not parsed reliably by text().
            db.execute(text("""
                INSERT INTO detections (id, media_id, frame_time_ms, hazard_signal, confidence, severity, bbox, polygon, rule_tags, pin, created_at)
                VALUES (gen_random_uuid()::text, :media_id, :ft, :sig, :conf, :sev, NULL, NULL, CAST(:tags AS jsonb),
                        ST_SetSRID(ST_MakePoint(:lng, :lat),4326)::geography, now())
            """), {
                "media_id": media_id,
                "ft": frame_time_ms,
                "sig": signal,
                "conf": conf,
                "sev": sev,
                "tags": tags_json,
                "lng": lng,
                "lat": lat,
            })

            # hazard case
            db.execute(text("""
                INSERT INTO hazards (id, site_id, status, hazard_type, title, description, severity, primary_rule_tags, pin, created_by, created_at, updated_at)
                VALUES (gen_random_uuid()::text, :site_id, 'OPEN', :ht, :title, NULL, :sev, CAST(:tags AS jsonb),
                        ST_SetSRID(ST_MakePoint(:lng, :lat),4326)::geography, :created_by, now(), now())
            """), {
                "site_id": row["site_id"],
                "ht": hz_type,
                "title": title,
                "sev": sev,
                "tags": tags_json,
                "lng": lng,
                "lat": lat,
                "created_by": row["uploader_id"],
            })

        db.commit()
        return {"ok": True, "detections": len(dets)}
    except Exception as e:
        db.rollback()
        # The error is reported through the return value; log it as well so the
        # failure is visible outside the task result.
        log.exception("process_media failed for media %s", media_id)
        return {"ok": False, "error": str(e)}
    finally:
        db.close()


from celery.schedules import crontab

# Periodic (Celery beat) jobs. The single entry below triggers the
# "worker.cleanup_old_media" task on the crontab(hour=3, minute=0) schedule.
celery_app.conf.beat_schedule = {
    "cleanup-old-media-daily": dict(
        task="worker.cleanup_old_media",
        schedule=crontab(hour=3, minute=0),
    ),
}

@celery_app.task(name="worker.cleanup_old_media")
def cleanup_old_media(days: int = 7):
    """Delete media rows older than ``days`` days, removing their files on a best-effort basis.

    Selects every ``media`` row whose ``created_at`` is older than
    ``now() - days``; for each one it tries to remove the file at ``s3_key``
    and then deletes the row. All row deletions are committed together.

    Args:
        days: Retention window in days. Negative numbers are rejected, because
            they would move the cutoff into the future and match every row.

    Returns:
        ``{"ok": True, "deleted": <row count>}`` on success, or
        ``{"ok": False, "error": <message>}`` on invalid input or when a
        database step raises (the transaction is rolled back in that case).
    """
    log = logging.getLogger(__name__)
    if isinstance(days, (int, float)) and days < 0:
        return {"ok": False, "error": "days must be >= 0"}

    db = SessionLocal()
    try:
        rows = db.execute(text("""
            SELECT id, s3_key
            FROM media
            WHERE created_at < (now() - (:days || ' days')::interval)
        """), {"days": days}).mappings().all()
        deleted = 0
        for r in rows:
            path = r["s3_key"]
            # NOTE(review): s3_key is used as a local filesystem path here, while
            # process_media reads it through storage.read_bytes — confirm the
            # storage backend really is the local filesystem.
            if path:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    # Already gone; nothing to clean up.
                    pass
                except (OSError, ValueError):
                    # File removal stays best-effort: the row is still deleted,
                    # but the failure is logged instead of silently dropped.
                    log.warning(
                        "Could not remove file %r for media %s",
                        path, r["id"], exc_info=True,
                    )
            db.execute(text("DELETE FROM media WHERE id = :id"), {"id": r["id"]})
            deleted += 1
        db.commit()
        return {"ok": True, "deleted": deleted}
    except Exception as e:
        db.rollback()
        log.exception("cleanup_old_media failed")
        return {"ok": False, "error": str(e)}
    finally:
        db.close()
