import os
import sys
import glob
import torch
import argparse
import shutil
import open_clip
from PIL import Image, ImageTk
import tkinter as tk
from tkinter import messagebox
from transformers import AutoProcessor, SiglipModel, AutoTokenizer
from concurrent.futures import ThreadPoolExecutor
import queue
import threading

# created by https://m14w.com/

# --- Model identifiers and per-model embedding cache files ------------------
# OpenCLIP ViT-H/14: the default model, used when neither --b32 nor --sigclip
# is given (see get_model_key).
CLIP_MODEL_NAME_H14 = 'ViT-H-14'
CLIP_PRETRAINED_H14 = 'laion2b_s32b_b79k'
CLIP_CACHE_FILE_H14 = 'clip_h14_embeddings_cache.pt'

# OpenCLIP ViT-B/32: the lighter model selected with --b32.
CLIP_MODEL_NAME_B32 = 'ViT-B-32'
CLIP_PRETRAINED_B32 = 'laion2b_s34b_b79k'
CLIP_CACHE_FILE_B32 = 'clip_b32_embeddings_cache.pt'

# SigLIP checkpoint loaded through Hugging Face transformers (--sigclip).
SIGLIP_MODEL_NAME = 'google/siglip-so400m-patch14-384'
SIGLIP_CACHE_FILE = 'siglip_embeddings_cache.pt'

# NOTE(review): not referenced in the visible code; the --num argparse option
# uses a literal default of 20 rather than this constant.
DISPLAY_N_DEFAULT = 20
# Default compute device; the model loaders fall back to it when device is None.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Lower-case extensions treated as images (file names are lower-cased before matching).
VALID_EXTS = ('.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff')

# NOTE(review): not referenced in the visible code; presumably the image count
# above which indexing is split into packets (the *_batch functions) — confirm.
BATCH_INDEXING_THRESHOLD = 10000

# NOTE(review): not referenced in the visible code; presumably the stale-entry
# count at which running with --prune is suggested — confirm.
PRUNE_SUGGEST_THRESHOLD = 500

# Default encoder batch sizes (normal / fast mode) and image-loader thread
# count; optimize_settings() adjusts all three based on detected resources.
IMG_BATCH_SIZE = 128
IMG_BATCH_SIZE_FAST = 256
NUM_LOAD_WORKERS = 4

# Per-model GPU memory thresholds in GiB, compared against
# SystemResources.vram_total (bytes / 1024**3). 'model_size' is the assumed
# footprint of the loaded weights used by the embedding memory estimates.
VRAM_REQUIREMENTS = {
    'clip_h14': {'minimum': 2.5, 'recommended': 5.0, 'model_size': 2.5},
    'clip_b32': {'minimum': 0.5, 'recommended': 2.0, 'model_size': 0.4},
    'siglip': {'minimum': 1.6, 'recommended': 3.0, 'model_size': 1.6}
}

# Per-model system RAM thresholds in GiB, compared against SystemResources.ram_total.
RAM_REQUIREMENTS = {
    'clip_h14': {'minimum': 8.0, 'recommended': 16.0},
    'clip_b32': {'minimum': 4.0, 'recommended': 8.0},
    'siglip': {'minimum': 6.0, 'recommended': 12.0}
}

class SystemResources:
    """Snapshot of host RAM and CUDA device 0 VRAM, all in GiB (bytes / 1024**3).

    Values are captured at construction time; call refresh() to re-read them.
    When detection fails, VRAM falls back to 0.0 and RAM to 8.0 total / 4.0 free.
    """

    def __init__(self):
        self.vram_total = 0.0
        self.vram_free = 0.0
        self.ram_total = 0.0
        self.ram_free = 0.0
        self.cuda_available = torch.cuda.is_available()
        self._detect()

    def _detect(self):
        """Populate the VRAM (when CUDA is available) and RAM figures."""
        if self.cuda_available:
            self._detect_vram()

        try:
            import psutil
        except ImportError:
            self._detect_ram_fallback()
        else:
            mem = psutil.virtual_memory()
            self.ram_total = mem.total / (1024**3)
            self.ram_free = mem.available / (1024**3)

    def _detect_vram(self):
        """Read total and free memory of CUDA device 0.

        torch.cuda.mem_get_info reports the driver's view of free memory, which
        accounts for other processes and for memory already reserved by PyTorch's
        caching allocator. If it is unavailable, this falls back to
        total - memory_allocated, which only accounts for this process's live
        tensors.
        """
        try:
            free_bytes, total_bytes = torch.cuda.mem_get_info(0)
        except Exception:
            try:
                total_bytes = torch.cuda.get_device_properties(0).total_memory
                free_bytes = total_bytes - torch.cuda.memory_allocated(0)
            except Exception:
                self.vram_total = 0.0
                self.vram_free = 0.0
                return
        self.vram_total = total_bytes / (1024**3)
        self.vram_free = free_bytes / (1024**3)

    def _detect_ram_fallback(self):
        """Read RAM from /proc/meminfo (Linux) when psutil is not installed.

        /proc/meminfo reports kB, hence the division by 1024**2 to get GiB.
        MemAvailable is preferred; kernels that lack it fall back to MemFree.
        On any error (e.g. non-Linux hosts) fixed defaults of 8 GB total and
        4 GB free are used.
        """
        try:
            with open('/proc/meminfo', 'r') as f:
                lines = f.readlines()
            available_kb = None
            free_kb = None
            for line in lines:
                if line.startswith('MemTotal:'):
                    self.ram_total = int(line.split()[1]) / (1024**2)
                elif line.startswith('MemAvailable:'):
                    available_kb = int(line.split()[1])
                elif line.startswith('MemFree:'):
                    free_kb = int(line.split()[1])
            chosen_kb = available_kb if available_kb is not None else free_kb
            if chosen_kb is not None:
                self.ram_free = chosen_kb / (1024**2)
        except Exception:
            self.ram_total = 8.0
            self.ram_free = 4.0

    def print_info(self):
        """Print a human-readable summary of the detected resources."""
        print("\n=== SYSTEM RESOURCES ===")
        print(f"RAM: {self.ram_free:.1f} GB free / {self.ram_total:.1f} GB total")
        if self.cuda_available:
            print(f"VRAM: {self.vram_free:.1f} GB free / {self.vram_total:.1f} GB total")
            print(f"GPU: {torch.cuda.get_device_name(0)}")
        else:
            print("VRAM: N/A (CUDA not available, using CPU)")
        print("========================\n")

    def refresh(self):
        """Re-run detection, overwriting the stored figures."""
        self._detect()


# Rough sizing heuristic used by estimate_embedding_memory_gb: ~1 GB of
# memory per 200,000 embeddings.
EMBEDDING_MEMORY_PER_200K = 1.0
# Fraction of currently free RAM/VRAM that an estimated allocation may use
# before a memory warning is raised (0.6 == 60%).
MEMORY_WARNING_THRESHOLD = 0.6


def estimate_embedding_memory_gb(num_images):
    """Estimate the GB needed to hold ``num_images`` embeddings (linear heuristic)."""
    chunks_of_200k = num_images / 200000
    return chunks_of_200k * EMBEDDING_MEMORY_PER_200K


def estimate_cache_file_size_gb(filename):
    """Return the on-disk size of ``filename`` in GiB (bytes / 1024**3).

    Returns 0.0 when the file is missing, cannot be stat'ed, or the name is
    invalid (e.g. contains a NUL byte), so callers treat "no cache" and
    "unreadable cache" alike.
    """
    try:
        # A single stat call: no window between an exists() check and getsize().
        size_bytes = os.path.getsize(filename)
    except (OSError, ValueError):
        return 0.0
    return size_bytes / (1024**3)


def check_memory_for_cache(cache_file, resources, device):
    """Warn if loading ``cache_file`` would take too much of free RAM/VRAM.

    The in-memory size is assumed to be ~2.5x the file's on-disk size (a
    heuristic). A warning is produced when that estimate exceeds
    MEMORY_WARNING_THRESHOLD of free VRAM (CUDA only) and/or of free RAM.

    Returns:
        (ok, messages): ``ok`` is True when no warning applies; ``messages``
        holds one human-readable string per warning.
    """
    on_disk_gb = estimate_cache_file_size_gb(cache_file)
    if on_disk_gb == 0:
        return True, []

    in_memory_gb = on_disk_gb * 2.5
    pct = MEMORY_WARNING_THRESHOLD * 100
    messages = []

    gpu_in_use = device == "cuda" and resources.cuda_available
    if gpu_in_use and in_memory_gb > resources.vram_free * MEMORY_WARNING_THRESHOLD:
        messages.append(
            f"Cache file ({on_disk_gb:.2f} GB on disk) may use ~{in_memory_gb:.2f} GB VRAM when loaded.\n"
            f"  This exceeds {pct:.0f}% of free VRAM ({resources.vram_free:.2f} GB)."
        )

    if in_memory_gb > resources.ram_free * MEMORY_WARNING_THRESHOLD:
        messages.append(
            f"Cache file ({on_disk_gb:.2f} GB on disk) may use ~{in_memory_gb:.2f} GB RAM when loaded.\n"
            f"  This exceeds {pct:.0f}% of free RAM ({resources.ram_free:.2f} GB)."
        )

    return not messages, messages


def check_memory_for_embeddings(num_total_images, num_to_compute, resources, device, model_key):
    """Collect warnings about holding and computing embeddings for this run.

    Args:
        num_total_images: every image whose embedding will be held in memory.
        num_to_compute: images that still need to be encoded.
        resources: object with ``cuda_available``, ``vram_free``, ``ram_free`` (GiB).
        device: "cuda" or "cpu".
        model_key: key into VRAM_REQUIREMENTS ('clip_h14', 'clip_b32', 'siglip').

    Returns:
        List of warning strings (empty when no threshold is crossed).
    """
    notes = []

    all_gb = estimate_embedding_memory_gb(num_total_images)
    fresh_gb = estimate_embedding_memory_gb(num_to_compute)
    model_gb = VRAM_REQUIREMENTS[model_key]['model_size']
    pct = MEMORY_WARNING_THRESHOLD * 100
    on_gpu = device == "cuda" and resources.cuda_available

    if on_gpu:
        needed_gb = model_gb + all_gb
        headroom_gb = resources.vram_free - model_gb
        if needed_gb > resources.vram_free:
            notes.append(
                f"VRAM SATURATION RISK: {num_total_images:,} images need ~{all_gb:.2f} GB for embeddings.\n"
                f"  Model needs ~{model_gb:.1f} GB. Total: ~{needed_gb:.2f} GB, but only {resources.vram_free:.2f} GB free.\n"
                f"  Embeddings will overflow to RAM (slower searches)."
            )
        elif all_gb > headroom_gb * MEMORY_WARNING_THRESHOLD:
            notes.append(
                f"VRAM WARNING: {num_total_images:,} images need ~{all_gb:.2f} GB for embeddings.\n"
                f"  After loading model (~{model_gb:.1f} GB), only ~{headroom_gb:.2f} GB VRAM remains.\n"
                f"  This exceeds {pct:.0f}% threshold."
            )

    if all_gb > resources.ram_free * MEMORY_WARNING_THRESHOLD:
        notes.append(
            f"RAM WARNING: {num_total_images:,} images need ~{all_gb:.2f} GB for embeddings.\n"
            f"  This exceeds {pct:.0f}% of free RAM ({resources.ram_free:.2f} GB).\n"
            f"  System may become unresponsive or use swap."
        )

    # Encoding headroom: new embeddings plus ~0.5 GB of working memory.
    if num_to_compute > 0 and on_gpu and fresh_gb + 0.5 > resources.vram_free - model_gb:
        notes.append(
            f"COMPUTE WARNING: {num_to_compute:,} new images to embed need ~{fresh_gb:.2f} GB.\n"
            f"  Processing will require careful memory management."
        )

    return notes


def prompt_memory_warning(warnings):
    """Show ``warnings`` in a banner and ask the user whether to continue.

    Returns True immediately (without prompting) when there is nothing to show;
    otherwise True only if the user answers 'y' (case-insensitive).
    """
    if not warnings:
        return True

    banner = "!" * 60
    print("\n" + banner)
    print("!!! MEMORY WARNINGS !!!")
    print(banner)
    for entry in warnings:
        print(f"\n{entry}")
    print("\n" + banner)

    reply = input("\nDo you want to continue anyway? (y/n): ")
    return reply.strip().lower() == 'y'


def get_model_key(use_siglip, use_b32):
    """Map the CLI model flags to a requirements key; SigLIP wins over B/32."""
    if use_siglip:
        return 'siglip'
    return 'clip_b32' if use_b32 else 'clip_h14'


def optimize_settings(resources, use_siglip=False, use_b32=False, user_fast_mode=False):
    """Derive batch sizes, loader workers and precision from detected resources.

    Args:
        resources: object exposing ``cuda_available``, ``vram_total`` and
            ``ram_total`` in GiB (e.g. a SystemResources instance).
        use_siglip, use_b32: model selection flags, resolved via get_model_key().
        user_fast_mode: whether the user requested float16 / larger batches.

    Returns:
        dict with keys 'batch_size', 'batch_size_fast', 'num_workers',
        'fast_mode', 'use_cpu', 'warnings' (list of str) and 'can_run'
        (False when the hardware is below the model's minimum).
    """
    model_key = get_model_key(use_siglip, use_b32)
    vram_req = VRAM_REQUIREMENTS[model_key]
    ram_req = RAM_REQUIREMENTS[model_key]

    settings = {
        'batch_size': IMG_BATCH_SIZE,
        'batch_size_fast': IMG_BATCH_SIZE_FAST,
        'num_workers': NUM_LOAD_WORKERS,
        'fast_mode': user_fast_mode,
        'use_cpu': False,
        'warnings': [],
        'can_run': True,
    }

    if not resources.cuda_available:
        # CPU path: float16 disabled, small batches, few loader threads.
        settings['use_cpu'] = True
        settings['fast_mode'] = False
        settings['batch_size'] = 16
        settings['batch_size_fast'] = 16
        settings['num_workers'] = 2
        settings['warnings'].append("No CUDA GPU detected. Running on CPU (significantly slower).")

        if resources.ram_total < ram_req['minimum']:
            settings['can_run'] = False
            settings['warnings'].append(
                f"INSUFFICIENT RAM: {resources.ram_total:.1f} GB detected, "
                f"minimum {ram_req['minimum']:.1f} GB required for {model_key.upper()}"
            )
        elif resources.ram_total < ram_req['recommended']:
            settings['batch_size'] = 8
            settings['num_workers'] = 1
            settings['warnings'].append(
                f"Low RAM: {resources.ram_total:.1f} GB (recommended: {ram_req['recommended']:.1f} GB). "
                f"Reduced batch size to {settings['batch_size']}"
            )
        return settings

    if resources.vram_total < vram_req['minimum']:
        settings['can_run'] = False
        settings['warnings'].append(
            f"INSUFFICIENT VRAM: {resources.vram_total:.1f} GB detected, "
            f"minimum {vram_req['minimum']:.1f} GB required for {model_key.upper()}"
        )
        return settings

    if resources.vram_total < vram_req['recommended']:
        # float16 is forced on here, so the encoders select 'batch_size_fast'
        # regardless of user_fast_mode; report the batch size actually used.
        settings['fast_mode'] = True
        settings['batch_size'] = 32
        settings['batch_size_fast'] = 64
        settings['warnings'].append(
            f"Limited VRAM: {resources.vram_total:.1f} GB (recommended: {vram_req['recommended']:.1f} GB). "
            f"Enabled float16, reduced batch size to {settings['batch_size_fast']}"
        )
    # Test the larger tier first so that each branch is reachable.
    elif resources.vram_total >= 12.0:
        if user_fast_mode:
            settings['batch_size_fast'] = 768
        else:
            settings['batch_size'] = 384
    elif resources.vram_total >= 8.0:
        if user_fast_mode:
            settings['batch_size_fast'] = 512
        else:
            settings['batch_size'] = 256

    if resources.ram_total < ram_req['minimum']:
        settings['warnings'].append(
            f"Low RAM: {resources.ram_total:.1f} GB (minimum: {ram_req['minimum']:.1f} GB). "
            f"May experience slowdowns due to swapping"
        )
        settings['num_workers'] = 2
    elif resources.ram_total >= 32.0:
        settings['num_workers'] = 8
    elif resources.ram_total >= 16.0:
        settings['num_workers'] = 6

    return settings


def check_resources_and_confirm(resources, use_siglip=False, use_b32=False, user_fast_mode=False):
    """Compute settings via optimize_settings, print warnings, and confirm if needed.

    When the hardware is below the model's minimum, the user is asked whether to
    proceed. Declining returns None. Accepting returns the settings overridden
    to the most conservative values (batch size 8, one loader worker).
    """
    cfg = optimize_settings(resources, use_siglip, use_b32, user_fast_mode)

    problems = cfg['warnings']
    if problems:
        print("\n!!! RESOURCE WARNINGS !!!")
        for message in problems:
            print(f"  - {message}")
        print()

    if cfg['can_run']:
        return cfg

    print("The system does not meet minimum requirements to run this model.")
    answer = input("Do you want to continue anyway? This may crash or be extremely slow. (y/n): ")
    if answer.strip().lower() != 'y':
        return None

    cfg.update(can_run=True, batch_size=8, batch_size_fast=8, num_workers=1)
    return cfg


def load_persistent_cache(filename, resources=None, device="cuda"):
    """Load the embedding cache dict from ``filename`` onto the CPU.

    Returns {} when the file does not exist, when the user declines after a
    memory warning (only checked if ``resources`` is given), or when loading
    fails for any reason (treated as a corrupt cache).

    NOTE: torch.load unpickles the file; only load caches this tool wrote itself.
    """
    if not os.path.exists(filename):
        return {}

    print(f"[{filename}] Found cache file. Loading...")

    if resources is not None:
        fits, warnings = check_memory_for_cache(filename, resources, device)
        if not fits and not prompt_memory_warning(warnings):
            print("Skipping cache load. Starting fresh.")
            return {}

    try:
        return torch.load(filename, map_location="cpu")
    except Exception as exc:
        print(f"Warning: Cache corrupt ({exc}). Starting fresh.")
        return {}

def save_persistent_cache(cache_data, filename):
    """Write ``cache_data`` to ``filename`` with torch.save, atomically.

    The data is first written to ``<filename>.tmp`` and then moved over the
    target with os.replace. An interrupted save (Ctrl+C, crash, full disk)
    therefore leaves the previous cache intact, instead of a truncated file
    the loader would discard as corrupt. The temporary file is removed on
    failure and the original exception is re-raised.
    """
    print(f"[{filename}] Saving cache update...")
    tmp_filename = f"{filename}.tmp"
    try:
        torch.save(cache_data, tmp_filename)
        os.replace(tmp_filename, filename)
    except BaseException:
        # Also covers KeyboardInterrupt so no stale .tmp file is left behind.
        try:
            os.remove(tmp_filename)
        except OSError:
            pass
        raise

def prune_cache(cache_data, filename):
    """Drop cache entries whose file is missing or whose mtime has changed.

    Mutates ``cache_data`` in place, printing progress every 1000 entries, and
    re-saves it to ``filename`` only when at least one entry was removed.

    Returns:
        (cache_data, removed_count)
    """
    print(f"[{filename}] Pruning cache...")
    before = len(cache_data)
    stale = []

    for position, entry_path in enumerate(cache_data):
        if position % 1000 == 0:
            sys.stdout.write(f"\rChecking {position}/{before}...")
            sys.stdout.flush()
        try:
            on_disk_mtime = os.path.getmtime(entry_path)
        except OSError:
            # The file is gone or unreadable, so the entry can never be reused.
            stale.append(entry_path)
            continue
        if cache_data[entry_path]['mtime'] != on_disk_mtime:
            stale.append(entry_path)

    for entry_path in stale:
        cache_data.pop(entry_path)

    removed = len(stale)
    print(f"\rPruned {removed} stale entries. Cache: {before} -> {len(cache_data)}")

    if removed:
        save_persistent_cache(cache_data, filename)

    return cache_data, removed


def load_model_clip(fast_mode=False, device=None, use_b32=False):
    """Load an OpenCLIP model in eval mode, with its preprocess transform and tokenizer.

    Args:
        fast_mode: on CUDA, cast the model to float16 and try torch.compile().
        device: target device; defaults to DEVICE.
        use_b32: load ViT-B/32 instead of the default ViT-H/14.

    Returns:
        (model, preprocess, tokenizer)
    """
    device = device or DEVICE
    if use_b32:
        model_name, pretrained = CLIP_MODEL_NAME_B32, CLIP_PRETRAINED_B32
    else:
        model_name, pretrained = CLIP_MODEL_NAME_H14, CLIP_PRETRAINED_H14

    print(f"Loading CLIP: {model_name} ({pretrained})...")
    model, _train_transform, preprocess = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained, device=device
    )
    tokenizer = open_clip.get_tokenizer(model_name)
    model.eval()

    if not (fast_mode and device == "cuda"):
        return model, preprocess, tokenizer

    print("Enabling fast mode: float16...")
    model = model.half()
    # NOTE(review): torch.compile() compiles the module's forward(); callers use
    # model.encode_image, which may bypass the compiled graph — verify the speedup.
    try:
        print("Attempting torch.compile()...")
        model = torch.compile(model, mode="reduce-overhead")
        print("torch.compile() enabled.")
    except Exception as e:
        print(f"torch.compile() not available: {e}")

    return model, preprocess, tokenizer

def load_model_siglip(fast_mode=False, device=None):
    """Load the SigLIP model in eval mode, plus its image processor and tokenizer.

    Args:
        fast_mode: on CUDA, load weights as float16 and try torch.compile().
        device: target device; defaults to DEVICE.

    Returns:
        (model, processor, tokenizer)
    """
    device = device or DEVICE
    half_precision = fast_mode and device == "cuda"
    print(f"Loading SigLIP: {SIGLIP_MODEL_NAME}...")

    load_kwargs = {'torch_dtype': torch.float16} if half_precision else {}
    model = SiglipModel.from_pretrained(SIGLIP_MODEL_NAME, **load_kwargs).to(device)

    processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(SIGLIP_MODEL_NAME)
    model.eval()

    if half_precision:
        print("Fast mode: float16 enabled.")
        # NOTE(review): torch.compile() compiles forward(); callers use
        # get_image_features, which may bypass the compiled graph — verify.
        try:
            print("Attempting torch.compile()...")
            model = torch.compile(model, mode="reduce-overhead")
            print("torch.compile() enabled.")
        except Exception as e:
            print(f"torch.compile() not available: {e}")

    return model, processor, tokenizer


class ConceptViewer:
    """Full-screen Tk viewer that pages through ranked search results.

    ``image_data`` is indexed as ``score, path = image_data[i]`` and must be
    non-empty, since the first entry is loaded immediately. Keys: Escape quits,
    Left/Right navigate, the mouse wheel scrolls in zoom mode, and a
    double-click toggles between the fit-to-screen and actual-size views.
    """

    def __init__(self, root, image_data, concept_text):
        self.root = root
        self.image_data = image_data
        self.concept = concept_text
        self.index = 0
        self.zoom_mode = False

        self.root.title(f"Concept: '{self.concept}'")
        self.root.attributes('-fullscreen', True)
        self.root.configure(bg='black')

        self.frame_main = tk.Frame(self.root, bg='black')
        self.frame_main.pack(fill='both', expand=True)
        self.frame_main.grid_rowconfigure(0, weight=1)
        self.frame_main.grid_columnconfigure(0, weight=1)

        # Scrollbars are created here but only gridded while in actual-size mode.
        self.v_scroll = tk.Scrollbar(self.frame_main, orient='vertical')
        self.h_scroll = tk.Scrollbar(self.frame_main, orient='horizontal')

        self.canvas = tk.Canvas(self.frame_main, bg='black', highlightthickness=0,
                                yscrollcommand=self.v_scroll.set, xscrollcommand=self.h_scroll.set)

        self.v_scroll.config(command=self.canvas.yview)
        self.h_scroll.config(command=self.canvas.xview)
        self.canvas.grid(row=0, column=0, sticky="nsew")

        self.lbl_info = tk.Label(self.root, text="", fg="#00FF00", bg="black", font=("Arial", 14, "bold"))
        self.lbl_info.place(x=20, y=20)

        self.lbl_help = tk.Label(self.root, text="ESC to Quit | Arrows to Navigate | Double Click to Zoom", fg="white", bg="black", font=("Arial", 10))
        self.lbl_help.place(x=20, y=60)

        self.root.bind("<Escape>", lambda e: self.root.destroy())
        self.root.bind("<Left>", self.prev_image)
        self.root.bind("<Right>", self.next_image)
        # <MouseWheel> reports event.delta; <Button-4>/<Button-5> are the X11 wheel events.
        self.root.bind("<MouseWheel>", self._on_mousewheel)
        self.root.bind("<Button-4>", self._on_mousewheel)
        self.root.bind("<Button-5>", self._on_mousewheel)
        self.canvas.bind("<Double-Button-1>", self.toggle_view)

        self.load_current_image()

    def _on_mousewheel(self, event):
        """Scroll vertically, only while the actual-size (zoom) view is shown."""
        if self.zoom_mode:
            if event.num == 5 or event.delta < 0:
                self.canvas.yview_scroll(1, "units")
            elif event.num == 4 or event.delta > 0:
                self.canvas.yview_scroll(-1, "units")

    def load_current_image(self, step=1):
        """Show image_data[self.index] fitted to the screen.

        Args:
            step: direction of travel (+1 forward, -1 backward). If the file
                cannot be opened or rendered, the viewer skips to the neighbour
                in that direction. Going Left past a broken file therefore
                continues leftwards instead of bouncing back to the image you
                came from. At index 0 it skips forward instead.
        """
        score, path = self.image_data[self.index]
        self.lbl_info.config(text=f"#{self.index + 1}/{len(self.image_data)} | Score: {score:.2f} | {os.path.basename(path)}")

        previous = getattr(self, "pil_image", None)
        try:
            self.pil_image = Image.open(path)
            self.zoom_mode = False
            self.show_fit()
        except Exception as e:
            print(f"Error loading {path}: {e}")
            if step < 0 and self.index > 0:
                self.prev_image()
            else:
                self.next_image()
            return

        # Release the file handle of the image that was just replaced.
        if previous is not None and previous is not self.pil_image:
            previous.close()

    def toggle_view(self, event=None):
        """Switch between the fit-to-screen and actual-size views."""
        self.zoom_mode = not self.zoom_mode
        if self.zoom_mode:
            self.show_actual_size()
        else:
            self.show_fit()

    def show_fit(self):
        """Render the current image scaled (aspect preserved) to fit the screen, centered."""
        self.v_scroll.grid_remove()
        self.h_scroll.grid_remove()
        screen_w = self.root.winfo_screenwidth()
        screen_h = self.root.winfo_screenheight()
        img_w, img_h = self.pil_image.size

        ratio = min(screen_w / img_w, screen_h / img_h)
        new_w = int(img_w * ratio)
        new_h = int(img_h * ratio)

        resized = self.pil_image.resize((new_w, new_h), Image.Resampling.LANCZOS)
        # Keep a reference on self: Tk does not hold one, and the image would be GC'd.
        self.tk_image = ImageTk.PhotoImage(resized)

        self.canvas.delete("all")
        self.canvas.create_image(screen_w // 2, screen_h // 2, image=self.tk_image, anchor='center')
        self.canvas.config(scrollregion=(0, 0, screen_w, screen_h))

    def show_actual_size(self):
        """Render the current image at 100% with scrollbars."""
        self.v_scroll.grid(row=0, column=1, sticky="ns")
        self.h_scroll.grid(row=1, column=0, sticky="ew")
        self.tk_image = ImageTk.PhotoImage(self.pil_image)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, image=self.tk_image, anchor='nw')
        self.canvas.config(scrollregion=self.canvas.bbox("all"))

    def next_image(self, event=None):
        """Advance one result; at the end, offer to close the viewer."""
        if self.index < len(self.image_data) - 1:
            self.index += 1
            self.load_current_image()
        else:
            if messagebox.askyesno("End", "End of list. Exit viewer?"):
                self.root.destroy()

    def prev_image(self, event=None):
        """Go back one result (no-op on the first one)."""
        if self.index > 0:
            self.index -= 1
            self.load_current_image(step=-1)


def load_single_image_clip(args):
    """Open and preprocess one image for CLIP (runs in a loader thread).

    Args:
        args: (path, mtime, preprocess) tuple, packed so the function can be
            used with ThreadPoolExecutor.map.

    Returns:
        (path, mtime, preprocessed_image), or None if the file cannot be read,
        decoded or preprocessed. Callers filter out the Nones.
    """
    path, mtime, preprocess = args
    try:
        # The context manager closes the file handle as soon as preprocessing
        # is done, rather than leaving it to the garbage collector.
        with Image.open(path) as img:
            processed = preprocess(img)
    except Exception:
        # Best-effort: unreadable or corrupt images are skipped, but
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        return None
    return (path, mtime, processed)

def load_single_image_siglip(args):
    """Open one image and convert it to RGB for the SigLIP processor.

    Args:
        args: (path, mtime) tuple, packed for ThreadPoolExecutor.map.

    Returns:
        (path, mtime, rgb_image), or None if the file cannot be read or decoded.
    """
    path, mtime = args
    try:
        # convert() returns a new, fully loaded image, so the source file can be
        # closed as soon as the with-block exits.
        with Image.open(path) as src:
            rgb = src.convert("RGB")
    except Exception:
        # Best-effort: skip unreadable images without swallowing KeyboardInterrupt.
        return None
    return (path, mtime, rgb)

def prefetch_batches_clip(paths_to_compute, preprocess, batch_size, result_queue, stop_event, num_workers):
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for i in range(0, len(paths_to_compute), batch_size):
            if stop_event.is_set():
                break
            batch_slice = paths_to_compute[i:i+batch_size]
            args_list = [(p, m, preprocess) for p, m in batch_slice]
            results = list(executor.map(load_single_image_clip, args_list))
            valid = [r for r in results if r is not None]
            result_queue.put((i, valid))
    result_queue.put(None)

def prefetch_batches_siglip(paths_to_compute, batch_size, result_queue, stop_event, num_workers):
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for i in range(0, len(paths_to_compute), batch_size):
            if stop_event.is_set():
                break
            batch_slice = paths_to_compute[i:i+batch_size]
            args_list = [(p, m) for p, m in batch_slice]
            results = list(executor.map(load_single_image_siglip, args_list))
            valid = [r for r in results if r is not None]
            result_queue.put((i, valid))
    result_queue.put(None)


def get_files(directory, recursive=False, exts=None):
    """Return paths of image files found in ``directory``.

    Args:
        directory: folder to scan ("" means the current directory).
        recursive: descend into subdirectories (os.walk) when True.
        exts: tuple of lower-case extensions to accept; defaults to VALID_EXTS.

    Extension matching is case-insensitive in both modes (``photo.JPG`` is
    found), and the directory name is never interpreted as a glob pattern, so
    folders like ``holiday [2020]`` work. In non-recursive mode only regular
    files (following symlinks) are returned, and an unreadable directory yields
    an empty list with a warning.
    """
    if exts is None:
        exts = VALID_EXTS
    files = []
    print(f"Scanning directory: {directory} (Recursive: {recursive})")
    if recursive:
        for root, _, filenames in os.walk(directory):
            files.extend(
                os.path.join(root, filename)
                for filename in filenames
                if filename.lower().endswith(exts)
            )
    else:
        try:
            with os.scandir(directory or os.curdir) as entries:
                for entry in entries:
                    if entry.name.lower().endswith(exts) and entry.is_file():
                        files.append(os.path.join(directory, entry.name))
        except OSError as e:
            print(f"Warning: cannot read directory {directory}: {e}")
    return files

def get_paths_to_compute(files, cache_dict):
    """Split ``files`` into cache hits and images that still need encoding.

    A cache entry is reused only if its stored 'mtime' equals the file's current
    mtime. Files that cannot be stat'ed are silently skipped.

    Returns:
        (paths_to_compute, cached_paths, cached_embeddings) where
        paths_to_compute is a list of (path, mtime) pairs and cached_embeddings
        is aligned with cached_paths.
    """
    pending, hit_paths, hit_embeddings = [], [], []
    for candidate in files:
        try:
            current = os.path.getmtime(candidate)
        except OSError:
            continue
        if candidate in cache_dict and cache_dict[candidate]['mtime'] == current:
            hit_paths.append(candidate)
            hit_embeddings.append(cache_dict[candidate]['embedding'])
        else:
            pending.append((candidate, current))
    return pending, hit_paths, hit_embeddings

def compute_embeddings_clip(files, model, preprocess, cache_dict, cache_file, settings, device):
    """Return CLIP embeddings for ``files``, reusing cached vectors where valid.

    Uncached images are loaded by a background prefetch thread and encoded in
    batches. New vectors are L2-normalized, stored as float32 CPU tensors in
    ``cache_dict`` and the cache is saved to ``cache_file`` afterwards.

    Returns:
        (embeddings, paths, cache_dict). ``embeddings`` is the stacked tensor on
        ``device`` with rows aligned to ``paths`` (cached first, then new ones
        in loader order), or None with ``paths == []`` when nothing is available.
    """
    batch_size = settings['batch_size_fast'] if settings['fast_mode'] else settings['batch_size']
    workers = settings['num_workers']

    print("Checking cache...")
    todo, valid_paths, embeddings_list = get_paths_to_compute(files, cache_dict)
    print(f"Used {len(valid_paths)} cached embeddings.")

    total = len(todo)
    if total > 0:
        print(f"Computing {total} new images (CLIP)... [batch_size={batch_size}, workers={workers}]")

        batches = queue.Queue(maxsize=2)
        halt = threading.Event()
        producer = threading.Thread(target=prefetch_batches_clip, args=(todo, preprocess, batch_size, batches, halt, workers))
        producer.start()

        try:
            # iter(get, None) keeps pulling batches until the None sentinel arrives.
            for offset, batch in iter(batches.get, None):
                if not batch:
                    continue

                sys.stdout.write(f"\rProcessing {offset}/{total}...")
                sys.stdout.flush()

                pixels = torch.stack([tensor for _, _, tensor in batch]).to(device)
                if settings['fast_mode'] and device == "cuda":
                    pixels = pixels.half()

                with torch.inference_mode():
                    feats = model.encode_image(pixels)
                    feats /= feats.norm(dim=-1, keepdim=True)
                    feats = feats.float().cpu()

                for row, (path, mtime, _) in enumerate(batch):
                    vector = feats[row]
                    cache_dict[path] = {'mtime': mtime, 'embedding': vector}
                    valid_paths.append(path)
                    embeddings_list.append(vector)
        finally:
            # Tell the producer to stop, then wait for it before continuing.
            halt.set()
            producer.join()

        print("\nNew computations done.")
        save_persistent_cache(cache_dict, cache_file)

    if not embeddings_list:
        return None, [], cache_dict
    return torch.stack(embeddings_list).to(device), valid_paths, cache_dict

def compute_embeddings_clip_batch(paths_to_compute, model, preprocess, cache_dict, cache_file, packet_index, total_packets, settings, device):
    """Encode one packet of uncached images with CLIP and persist the cache.

    Unlike compute_embeddings_clip, no stacked tensor is returned: new
    L2-normalized float32 CPU vectors are only written into ``cache_dict``,
    which is then saved to ``cache_file``.

    Args:
        paths_to_compute: list of (path, mtime) pairs to encode.
        packet_index, total_packets: used only in progress messages.

    Returns:
        ``cache_dict``, unchanged (and without saving) when the packet is empty.
    """
    total = len(paths_to_compute)
    if not total:
        return cache_dict

    batch_size = settings['batch_size_fast'] if settings['fast_mode'] else settings['batch_size']
    workers = settings['num_workers']
    tag = f"[Packet {packet_index}/{total_packets}]"
    print(f"\n{tag} Computing {total} images (CLIP)... [batch_size={batch_size}]")

    batches = queue.Queue(maxsize=2)
    halt = threading.Event()
    producer = threading.Thread(target=prefetch_batches_clip, args=(paths_to_compute, preprocess, batch_size, batches, halt, workers))
    producer.start()

    try:
        for offset, batch in iter(batches.get, None):
            if not batch:
                continue

            sys.stdout.write(f"\r{tag} Processing {offset}/{total} | Remaining: {total - offset}...")
            sys.stdout.flush()

            pixels = torch.stack([tensor for _, _, tensor in batch]).to(device)
            if settings['fast_mode'] and device == "cuda":
                pixels = pixels.half()

            with torch.inference_mode():
                feats = model.encode_image(pixels)
                feats /= feats.norm(dim=-1, keepdim=True)
                feats = feats.float().cpu()

            for row, (path, mtime, _) in enumerate(batch):
                cache_dict[path] = {'mtime': mtime, 'embedding': feats[row]}
    finally:
        halt.set()
        producer.join()

    print(f"\n{tag} Done. Saving cache...")
    save_persistent_cache(cache_dict, cache_file)
    return cache_dict

def compute_embeddings_siglip(files, model, processor, cache_dict, cache_file, settings, device):
    """Return SigLIP embeddings for ``files``, reusing cached vectors where valid.

    Uncached images are loaded as RGB by a background prefetch thread, run
    through ``processor`` and encoded with ``model.get_image_features``. New
    vectors are L2-normalized, stored as float32 CPU tensors in ``cache_dict``
    and the cache is saved to ``cache_file`` afterwards.

    Returns:
        (embeddings, paths, cache_dict). ``embeddings`` is stacked on ``device``
        with rows aligned to ``paths``, or None with ``paths == []`` when
        nothing is available.
    """
    batch_size = settings['batch_size_fast'] if settings['fast_mode'] else settings['batch_size']
    workers = settings['num_workers']

    print("Checking cache...")
    todo, valid_paths, embeddings_list = get_paths_to_compute(files, cache_dict)
    print(f"Used {len(valid_paths)} cached embeddings.")

    total = len(todo)
    if total > 0:
        print(f"Computing {total} new images (SigLIP)... [batch_size={batch_size}, workers={workers}]")

        batches = queue.Queue(maxsize=2)
        halt = threading.Event()
        producer = threading.Thread(target=prefetch_batches_siglip, args=(todo, batch_size, batches, halt, workers))
        producer.start()

        try:
            for offset, batch in iter(batches.get, None):
                if not batch:
                    continue

                sys.stdout.write(f"\rProcessing {offset}/{total}...")
                sys.stdout.flush()

                pictures = [picture for _, _, picture in batch]
                inputs = processor(images=pictures, return_tensors="pt").to(device)
                if settings['fast_mode'] and device == "cuda":
                    # Only floating-point pixel tensors are cast to half.
                    inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

                with torch.inference_mode():
                    feats = model.get_image_features(**inputs)
                    feats /= feats.norm(dim=-1, keepdim=True)
                    feats = feats.float().cpu()

                for row, (path, mtime, _) in enumerate(batch):
                    vector = feats[row]
                    cache_dict[path] = {'mtime': mtime, 'embedding': vector}
                    valid_paths.append(path)
                    embeddings_list.append(vector)
        finally:
            halt.set()
            producer.join()

        print("\nNew computations done.")
        save_persistent_cache(cache_dict, cache_file)

    if not embeddings_list:
        return None, [], cache_dict
    return torch.stack(embeddings_list).to(device), valid_paths, cache_dict

def compute_embeddings_siglip_batch(paths_to_compute, model, processor, cache_dict, cache_file, packet_index, total_packets, settings, device):
    """Encode one packet of uncached images with SigLIP and persist the cache.

    New L2-normalized float32 CPU vectors are written into ``cache_dict``,
    which is then saved to ``cache_file``; no stacked tensor is returned.

    Args:
        paths_to_compute: list of (path, mtime) pairs to encode.
        packet_index, total_packets: used only in progress messages.

    Returns:
        ``cache_dict``, unchanged (and without saving) when the packet is empty.
    """
    total = len(paths_to_compute)
    if not total:
        return cache_dict

    batch_size = settings['batch_size_fast'] if settings['fast_mode'] else settings['batch_size']
    workers = settings['num_workers']
    tag = f"[Packet {packet_index}/{total_packets}]"
    print(f"\n{tag} Computing {total} images (SigLIP)... [batch_size={batch_size}]")

    batches = queue.Queue(maxsize=2)
    halt = threading.Event()
    producer = threading.Thread(target=prefetch_batches_siglip, args=(paths_to_compute, batch_size, batches, halt, workers))
    producer.start()

    try:
        for offset, batch in iter(batches.get, None):
            if not batch:
                continue

            sys.stdout.write(f"\r{tag} Processing {offset}/{total} | Remaining: {total - offset}...")
            sys.stdout.flush()

            pictures = [picture for _, _, picture in batch]
            inputs = processor(images=pictures, return_tensors="pt").to(device)
            if settings['fast_mode'] and device == "cuda":
                inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

            with torch.inference_mode():
                feats = model.get_image_features(**inputs)
                feats /= feats.norm(dim=-1, keepdim=True)
                feats = feats.float().cpu()

            for row, (path, mtime, _) in enumerate(batch):
                cache_dict[path] = {'mtime': mtime, 'embedding': feats[row]}
    finally:
        halt.set()
        producer.join()

    print(f"\n{tag} Done. Saving cache...")
    save_persistent_cache(cache_dict, cache_file)
    return cache_dict

def load_embeddings_from_cache(files, cache_dict, device):
    """Gather cached embeddings for files whose cache entry is still current.

    A file is kept only when its modification time can still be read (so it
    exists) AND `cache_dict` has an entry for it whose stored 'mtime' equals
    the file's current mtime. Input order is preserved.

    Args:
        files: Iterable of image paths to look up.
        cache_dict: Mapping path -> {'mtime': ..., 'embedding': tensor}.
        device: Device the stacked embeddings are moved to.

    Returns:
        (features, paths): the kept embeddings stacked along a new leading
        dimension and moved to `device`, plus the matching list of paths;
        or (None, []) when no file matched.
    """
    matched_paths = []
    matched_embeddings = []
    for file_path in files:
        try:
            current_mtime = os.path.getmtime(file_path)
        except OSError:
            # Missing or unreadable file: nothing to load for it.
            continue
        if file_path not in cache_dict:
            continue
        entry = cache_dict[file_path]
        if entry['mtime'] != current_mtime:
            # Stale entry: the file changed after it was embedded.
            continue
        matched_paths.append(file_path)
        matched_embeddings.append(entry['embedding'])

    if len(matched_embeddings) == 0:
        return None, []
    stacked = torch.stack(matched_embeddings)
    return stacked.to(device), matched_paths


def _rank_by_concept(concept, model, tokenizer, image_features, valid_paths, use_siglip, device):
    """Score every indexed image against a text concept.

    Returns a list of (score, path) tuples sorted by descending score; images
    with equal scores keep their index order (stable sort).
    """
    with torch.inference_mode():
        if use_siglip:
            inputs = tokenizer([concept], padding="max_length", truncation=True, return_tensors="pt").to(device)
            text_features = model.get_text_features(**inputs)
        else:
            text_token = tokenizer([concept]).to(device)
            text_features = model.encode_text(text_token)

        text_features /= text_features.norm(dim=-1, keepdim=True)
        # Embeddings read back from the cache are float32 (they are cast with
        # `.float()` before being cached), while a fast-mode model may emit
        # float16 text features; matmul requires matching dtypes.
        text_features = text_features.to(image_features.dtype)

        # (N, D) @ (D, 1) -> (N, 1). Taking column 0 always yields a 1-D tensor,
        # even for a single image (where `.squeeze()` would give a 0-d scalar).
        similarity = (100.0 * image_features @ text_features.T)[:, 0]
        scores = similarity.float().cpu().tolist()

    results = list(zip(scores, valid_paths))
    results.sort(key=lambda x: x[0], reverse=True)
    return results


def _transfer_results(top_results, move):
    """Interactively copy (or move, when `move` is true) the displayed results.

    Returns the set of source paths that were successfully moved. The set is
    empty for copies or when the user declines, so the caller can drop moved
    files from its in-memory index.
    """
    action_name = "MOVE" if move else "COPY"
    print(f"\n--- {action_name} OPERATION ---")

    confirm = input(f"Do you want to {action_name} these {len(top_results)} specific images? (y/n): ").strip().lower()
    if confirm != 'y':
        return set()

    dest_dir = input(f"Enter destination directory for {action_name}: ").strip().strip('"').strip("'")
    if not dest_dir:
        print("Skipping operation.")
        return set()

    os.makedirs(dest_dir, exist_ok=True)

    moved = set()
    for i, (_score, src_path) in enumerate(top_results, start=1):
        filename = os.path.basename(src_path)
        # The rank prefix keeps files in result order and keeps same-named
        # files from different source folders distinct within this batch.
        dest_path = os.path.join(dest_dir, f"{i:03d}_{filename}")

        try:
            if move:
                shutil.move(src_path, dest_path)
                moved.add(src_path)
                print(f"Moved: {filename}")
            else:
                shutil.copy2(src_path, dest_path)
                print(f"Copied: {filename}")
        except OSError as e:
            # Best-effort: report the failing file and continue with the rest.
            print(f"Error processing {filename}: {e}")

    print(f"{action_name} complete.")
    return moved


def main():
    """Interactive CLI: index a folder of images and browse them by text concept.

    Steps:
    1. Parse flags and pick a model (SigLIP, OpenCLIP ViT-B/32 or ViT-H/14).
    2. Pick batch settings from the detected system resources.
    3. Compute embeddings for images not already current in the on-disk cache.
       This runs in one pass, or in saved packets for large sets.
    4. Optionally export a snapshot.
    5. Repeatedly prompt for a concept and show the top matches in a Tk viewer.
       The matches can optionally be copied or moved.
    """
    parser = argparse.ArgumentParser(description="Find images matching a concept.")
    parser.add_argument("--sub", action="store_true", help="Look inside subdirectories")
    parser.add_argument("--sigclip", action="store_true", help="Use SigLIP instead of CLIP")
    parser.add_argument("--b32", action="store_true", help="Use ViT-B/32 (lighter) instead of ViT-H/14")
    parser.add_argument("--copy", action="store_true", help="Copy matching images to a folder")
    parser.add_argument("--move", action="store_true", help="Move matching images to a folder")
    parser.add_argument("-s", "--save", type=str, help="Export processed data to a specific file (.pt)")
    parser.add_argument("--fast", action="store_true", help="Enable fast mode: float16 + torch.compile + larger batches")
    parser.add_argument("--prune", action="store_true", help="Remove stale entries from cache before processing")
    parser.add_argument("--skip-check", action="store_true", help="Skip system resource check")
    parser.add_argument("-n", "--num", type=int, default=DISPLAY_N_DEFAULT,
                        help=f"Number of top results to display (default: {DISPLAY_N_DEFAULT})")
    args = parser.parse_args()

    # A zero or negative count would make `results[:display_n]` empty or drop
    # the tail instead of limiting the output.
    if args.num < 1:
        parser.error("-n/--num must be a positive integer")
    display_n = args.num

    # Human-readable model name and the embeddings cache file that goes with it.
    if args.sigclip:
        mode_name = "SigLIP"
        current_cache_file = SIGLIP_CACHE_FILE
    elif args.b32:
        mode_name = "OpenCLIP ViT-B/32"
        current_cache_file = CLIP_CACHE_FILE_B32
    else:
        mode_name = "OpenCLIP ViT-H/14"
        current_cache_file = CLIP_CACHE_FILE_H14

    resources = SystemResources()
    resources.print_info()

    if not args.skip_check:
        settings = check_resources_and_confirm(resources, args.sigclip, args.b32, args.fast)
        if settings is None:
            print("Exiting.")
            return
    else:
        settings = optimize_settings(resources, args.sigclip, args.b32, args.fast)

    device = "cpu" if settings['use_cpu'] else DEVICE

    fast_str = " [FAST MODE]" if settings['fast_mode'] else ""
    cpu_str = " [CPU]" if settings['use_cpu'] else ""
    print(f"\n--- {mode_name} CONCEPT VIEWER ({device}){fast_str}{cpu_str} ---")
    print(f"Optimized settings: batch_size={settings['batch_size_fast'] if settings['fast_mode'] else settings['batch_size']}, workers={settings['num_workers']}")

    root_dir = input("\nEnter the directory path to scan: ").strip().strip('"').strip("'")
    # Must be a directory, not merely an existing path (a file would slip past exists()).
    if not os.path.isdir(root_dir):
        print("Directory not found.")
        return

    resources.refresh()
    cache_data = load_persistent_cache(current_cache_file, resources, device)

    if args.prune and cache_data:
        cache_data, _ = prune_cache(cache_data, current_cache_file)

    files = get_files(root_dir, recursive=args.sub)
    if not files:
        print("No images found.")
        return

    print(f"Found {len(files)} files.")

    print("Analyzing cache status...")
    paths_to_compute, cached_paths, _ = get_paths_to_compute(files, cache_data)
    num_to_index = len(paths_to_compute)
    num_cached = len(cached_paths)
    print(f"Cached: {num_cached} | Need indexing: {num_to_index}")

    model_key = get_model_key(args.sigclip, args.b32)
    resources.refresh()
    memory_warnings = check_memory_for_embeddings(
        num_total_images=len(files),
        num_to_compute=num_to_index,
        resources=resources,
        device=device,
        model_key=model_key
    )
    if memory_warnings:
        if not prompt_memory_warning(memory_warnings):
            print("Exiting due to memory concerns.")
            return

    if not args.prune and num_to_index > PRUNE_SUGGEST_THRESHOLD and cache_data:
        print(f"\n*** {num_to_index} new files to index (threshold: {PRUNE_SUGGEST_THRESHOLD}) ***")
        choice = input("Prune stale cache entries before indexing? (y/n): ").strip().lower()
        if choice == 'y':
            cache_data, _ = prune_cache(cache_data, current_cache_file)
            paths_to_compute, cached_paths, _ = get_paths_to_compute(files, cache_data)
            num_to_index = len(paths_to_compute)
            num_cached = len(cached_paths)
            print(f"After prune - Cached: {num_cached} | Need indexing: {num_to_index}")

    batch_mode = False
    if num_to_index > BATCH_INDEXING_THRESHOLD:
        print(f"\n*** Large dataset detected: {num_to_index} files need indexing (threshold: {BATCH_INDEXING_THRESHOLD}) ***")
        choice = input("Enter batch indexing mode? This saves progress every packet. (y/n): ").strip().lower()
        batch_mode = (choice == 'y')

    if batch_mode:
        total_packets = (num_to_index + BATCH_INDEXING_THRESHOLD - 1) // BATCH_INDEXING_THRESHOLD
        print(f"\nBatch indexing: {num_to_index} files in {total_packets} packet(s) of up to {BATCH_INDEXING_THRESHOLD} files each.")

        if args.sigclip:
            model, processor, tokenizer = load_model_siglip(fast_mode=settings['fast_mode'], device=device)
        else:
            model, preprocess, tokenizer = load_model_clip(fast_mode=settings['fast_mode'], device=device, use_b32=args.b32)

        for packet_idx in range(total_packets):
            start = packet_idx * BATCH_INDEXING_THRESHOLD
            end = min(start + BATCH_INDEXING_THRESHOLD, num_to_index)
            packet = paths_to_compute[start:end]

            remaining_total = num_to_index - start
            print(f"\n{'='*60}")
            print(f"Packet {packet_idx + 1}/{total_packets} | Files: {len(packet)} | Total remaining: {remaining_total}")
            print(f"{'='*60}")

            if args.sigclip:
                cache_data = compute_embeddings_siglip_batch(packet, model, processor, cache_data, current_cache_file, packet_idx + 1, total_packets, settings, device)
            else:
                cache_data = compute_embeddings_clip_batch(packet, model, preprocess, cache_data, current_cache_file, packet_idx + 1, total_packets, settings, device)

        print(f"\n{'='*60}")
        print("Batch indexing complete. Loading all embeddings from cache...")
        print(f"{'='*60}")
        image_features, valid_paths = load_embeddings_from_cache(files, cache_data, device)
    else:
        if args.sigclip:
            model, processor, tokenizer = load_model_siglip(fast_mode=settings['fast_mode'], device=device)
            image_features, valid_paths, cache_data = compute_embeddings_siglip(
                files, model, processor, cache_data, current_cache_file, settings, device
            )
        else:
            model, preprocess, tokenizer = load_model_clip(fast_mode=settings['fast_mode'], device=device, use_b32=args.b32)
            image_features, valid_paths, cache_data = compute_embeddings_clip(
                files, model, preprocess, cache_data, current_cache_file, settings, device
            )

    if image_features is None:
        print("No valid features extracted.")
        return

    if args.save:
        print(f"Exporting snapshot to '{args.save}'...")
        relative_paths = [os.path.relpath(p, start=root_dir) for p in valid_paths]
        save_payload = {'features': image_features.cpu(), 'paths': relative_paths}
        torch.save(save_payload, args.save)
        print("Export complete.")

    while True:
        concept = input("\nEnter the CONCEPT (or 'q' to quit): ").strip()
        if not concept or concept.lower() == 'q':
            break

        print(f"Encoding concept: '{concept}'...")
        results = _rank_by_concept(concept, model, tokenizer, image_features, valid_paths, args.sigclip, device)
        top_results = results[:display_n]

        if not top_results:
            print("No matches found.")
            continue

        print(f"Launching viewer with top {len(top_results)} results...")

        root = tk.Tk()
        # Held in a local, presumably so the viewer (and any images it owns)
        # stays referenced while the Tk loop runs.
        _viewer = ConceptViewer(root, top_results, concept)
        root.mainloop()

        if args.copy or args.move:
            moved = _transfer_results(top_results, move=args.move)
            if moved:
                # Moved files no longer exist at their indexed paths. Drop them
                # so later queries don't show or re-move missing files.
                keep = [j for j, p in enumerate(valid_paths) if p not in moved]
                keep_idx = torch.tensor(keep, dtype=torch.long, device=image_features.device)
                image_features = image_features.index_select(0, keep_idx)
                valid_paths = [valid_paths[j] for j in keep]
                print(f"Removed {len(moved)} moved file(s) from the in-memory index.")

if __name__ == '__main__':
    # Run the interactive CLI only when executed as a script, not on import.
    main()
