In [None]:
# ==================================================
# Synthetic Abstract Generation for Concept Evaluation
# ==================================================
import json
import os
import random
import re  # Import the regex module
import time
from collections import defaultdict, Counter

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import requests
import seaborn as sns
from tqdm.notebook import tqdm

# Seed NumPy's global RNG so the np.random-based sampling in this file is reproducible
np.random.seed(4242)

# --- Configuration ---
# The Groq API key is read from the GROQ_API_KEY environment variable so that it
# never has to be hardcoded in (or committed with) this file. When the variable
# is not set, the key stays an empty string, exactly as before.
groq_api_key = os.environ.get("GROQ_API_KEY", "")

# --- Topic Space and Diversity Parameters ---
# TOPICS maps a topic id to its display name and its list of subtopics.
# Insertion order is significant: it fixes the node order of the topic network
# and therefore the outcome of seeded sampling.
TOPICS = {
    "T1": dict(
        name="Machine Learning",
        subtopics=[
            "Neural Networks",
            "Reinforcement Learning",
            "Supervised Learning",
            "Unsupervised Learning",
            "Transfer Learning",
        ],
    ),
    "T7": dict(
        name="Sustainable Development",
        subtopics=[
            "Renewable Energy",
            "Climate Change Mitigation",
            "Resource Management",
            "Environmental Monitoring",
            "Sustainable Cities",
        ],
    ),
    "T8": dict(
        name="Behavioral Economics",
        subtopics=[
            "Decision Making",
            "Cognitive Biases",
            "Risk Assessment",
            "Social Preferences",
            "Intertemporal Choice",
        ],
    ),
    "T9": dict(
        name="Digital Security",
        subtopics=[
            "Cybersecurity",
            "Privacy Enhancing Technologies",
            "Authentication Methods",
            "Threat Detection",
            "Security Policy",
        ],
    ),
    "T10": dict(
        name="Public Health",
        subtopics=[
            "Epidemiology",
            "Health Promotion",
            "Disease Prevention",
            "Health Equity",
            "Health Systems",
        ],
    ),
}

# Application domains; one is sampled whenever diversity parameters are drawn.
DOMAINS = [
    "Sports",
    "Marriage",
    "Childcare",
    "Exercise",
    "School",
    "Social Media",
    "Advertisement",
]

# Each key is one diversity dimension; its list holds the values it can take.
DIVERSITY_PARAMS = {
    "methodological_approaches": [
        "Theory",
        "Qualitative",
        "Randomized Experiments",
        "Quasi-Experimental",
        "Survey",
        "Correlational",
        "Mixed-methods",
    ],
    "concept_granularity": [
        "General Principles",
        "Specific Applications",
        "Mixed",
    ],
    "interdisciplinary_orientation": [
        "Pure-discipline",
        "Multi-disciplinary",
    ],
    "rhetorical_structures": [
        "Problem-solution",
        "Contribution-focused",
        "Findings-centered",
        "Process-oriented",
    ],
    "terminology_density": [
        "Terminology-rich",
        "Balanced",
        "Minimal Jargon",
    ],
    "temporal_context": [
        "Contemporary",
        "Historical Context",
        "Future-oriented",
    ],
    "concept_blending": [
        "Separate",
        "Integrated",
    ],
}

# --- Helper Functions ---

def create_topic_network():
    """Build an undirected weighted graph with one node per entry of TOPICS.

    Every node stores the topic's ``name`` and ``subtopics`` as attributes.
    Every pair of distinct topics is linked by one edge whose ``weight`` is the
    Jaccard overlap of their subtopic sets plus Gaussian noise (mean 0,
    standard deviation 0.1), clipped to the range [0.05, 0.95].

    Returns:
        networkx.Graph: The topic network.
    """
    network = nx.Graph()

    # Register each topic together with its descriptive attributes.
    for tid, info in TOPICS.items():
        network.add_node(tid, name=info["name"], subtopics=info["subtopics"])

    topic_ids = list(TOPICS)
    for first in topic_ids:
        first_subs = set(TOPICS[first]["subtopics"])
        for second in topic_ids:
            if first == second:
                continue
            second_subs = set(TOPICS[second]["subtopics"])
            overlap = len(first_subs & second_subs) / len(first_subs | second_subs)
            # Noise is drawn for every ordered pair, including pairs that are
            # already linked; only the first draw of a pair becomes the weight.
            noisy = overlap + np.random.normal(0, 0.1)
            weight = max(0.05, min(0.95, noisy))
            if not network.has_edge(first, second):
                network.add_edge(first, second, weight=weight)

    return network


def sample_concept_mix(topic_network, num_topics=None):
    """Sample a weighted mix of topics from the topic network.

    The first topic is drawn uniformly from all nodes. Each further topic is
    drawn from the not-yet-selected neighbours of the topics chosen so far,
    with probability proportional to the edge weight. When a neighbour is
    reachable from several selected topics, the weight of the first edge
    found is used.

    Args:
        topic_network: Graph exposing ``nodes()``, ``neighbors(node)`` and
            ``graph[a][b]['weight']`` (e.g. a ``networkx.Graph``).
        num_topics: Number of topics to mix. When ``None`` it is drawn from
            {1, 2, 3} with probabilities 0.2 / 0.6 / 0.2.

    Returns:
        dict: Maps each selected topic to its mixing weight, in selection
        order. The weights are ``[1.0]`` for one topic, ``[0.7, 0.3]`` for two,
        ``[0.5, 0.3, 0.2]`` for three, and equal shares for any other count
        (including the case where fewer topics than requested were available).

    Raises:
        ValueError: Raised by ``np.random.choice`` when the network has no nodes.
    """
    if num_topics is None:
        num_topics = np.random.choice([1, 2, 3], p=[0.2, 0.6, 0.2])

    all_topics = list(topic_network.nodes())

    if num_topics == 1:
        return {np.random.choice(all_topics): 1.0}

    selected_topics = [np.random.choice(all_topics)]
    for _ in range(num_topics - 1):
        # candidate -> weight of the first edge that reaches it; the dict keeps
        # discovery order, so the probability vector stays aligned with names.
        candidates = {}
        for chosen in selected_topics:
            for neighbor in topic_network.neighbors(chosen):
                if neighbor not in selected_topics:
                    candidates.setdefault(neighbor, topic_network[chosen][neighbor]['weight'])

        if not candidates:
            # No reachable neighbour: fall back to any unselected topic.
            # Built in node order (not via a set) so that the seeded draw is
            # reproducible regardless of string hash randomization.
            remaining = [t for t in all_topics if t not in selected_topics]
            if not remaining:
                break
            selected_topics.append(np.random.choice(remaining))
            continue

        names = list(candidates)
        raw_weights = list(candidates.values())
        total_weight = sum(raw_weights)
        if total_weight <= 0:
            # All weights are zero (or cancel out): use equal probabilities.
            probabilities = [1 / len(raw_weights)] * len(raw_weights)
        else:
            probabilities = [w / total_weight for w in raw_weights]
        selected_topics.append(np.random.choice(names, p=probabilities))

    if len(selected_topics) == 2:
        weights = [0.7, 0.3]
    elif len(selected_topics) == 3:
        weights = [0.5, 0.3, 0.2]
    else:
        # Fewer (or more) topics than the predefined profiles cover.
        weights = [1.0 / len(selected_topics)] * len(selected_topics)

    return dict(zip(selected_topics, weights))

    
# --- LLM Response Parsing ---
# Required fields of a parsed LLM answer: (key, required type, label used in messages).
_LLM_FIELD_SPECS = (
    ("title", str, "string"),
    ("abstract", str, "string"),
    ("keywords", list, "list"),
)


def _llm_parse_failure(error_msg, raw_text, **details):
    """Build the error dictionary returned when parsing or validation fails."""
    failure = {"error": error_msg}
    failure.update(details)
    failure["raw_response"] = raw_text[:500]  # First 500 chars for context
    return failure


def parse_llm_response(response_text: str) -> dict:
    """
    Extract and validate the JSON object an LLM was asked to return.

    The response may be wrapped in surrounding text or a markdown code fence.
    Processing steps:

    1. Strip whitespace; turn raw tabs/newlines into single spaces (so words
       on adjacent lines are not glued together) and drop all other control
       characters.
    2. Remove a leading/trailing markdown fence (``` or ```json).
    3. Take the text from the first '{' to the last '}'.
    4. Parse it as JSON as-is. Only if that fails, remove trailing commas
       before '}' / ']' and parse again, so that valid JSON whose string
       values happen to contain ", }" or ", ]" is never altered.
    5. Check that 'title' and 'abstract' are strings and 'keywords' is a list
       of strings.

    Progress and error messages are printed rather than logged.

    Args:
        response_text: The raw string response from the LLM.

    Returns:
        The parsed dictionary (containing at least 'title', 'abstract' and
        'keywords') on success. On failure, a dictionary with 'error', a
        'raw_response' snippet (at most 500 characters) and, when JSON
        decoding succeeded but validation failed, 'parsed_data'.
    """
    if not response_text:
        print("ERROR: LLM response was empty.")
        return {"error": "Empty response", "raw_response": ""}

    # 1. Normalize whitespace and remove control characters.
    response_text = response_text.strip()
    response_text = re.sub(r'[\t\n\r]+', ' ', response_text)
    response_text = re.sub(r'[\x00-\x1F\x7F-\x9F]', '', response_text)

    # 2. Remove potential markdown wrappers (```json ... ``` or ``` ... ```).
    response_text = re.sub(r'^```(?:json)?\s*', '', response_text, flags=re.MULTILINE)
    response_text = re.sub(r'\s*```$', '', response_text, flags=re.MULTILINE)
    response_text = response_text.strip()

    # 3. Locate the outermost '{...}' span.
    start_index = response_text.find('{')
    end_index = response_text.rfind('}')
    if start_index == -1 or end_index == -1 or end_index < start_index:
        error_msg = "Could not find valid JSON structure '{...}' in the response."
        print(f"ERROR: {error_msg} Raw Response Snippet: {response_text[:200]}")
        return _llm_parse_failure(error_msg, response_text)

    json_str_cleaned = response_text[start_index : end_index + 1]

    try:
        # 4. Strict parse first; fall back to trailing-comma repair only on failure.
        try:
            parsed_data = json.loads(json_str_cleaned)
        except json.JSONDecodeError:
            json_str_cleaned = re.sub(r',\s*([}\]])', r'\1', json_str_cleaned)
            parsed_data = json.loads(json_str_cleaned)

        # Defensive: the JSON standard allows non-object root values.
        if not isinstance(parsed_data, dict):
            error_msg = f"Parsed JSON is not a dictionary (root type: {type(parsed_data).__name__})."
            print(f"ERROR: {error_msg} Parsed Data: {str(parsed_data)[:200]}")
            return _llm_parse_failure(error_msg, response_text, parsed_data=parsed_data)

        # 5. Validation: presence of required keys ...
        missing_keys = [key for key, _, _ in _LLM_FIELD_SPECS if key not in parsed_data]
        if missing_keys:
            error_msg = f"Parsed JSON missing required keys: {', '.join(missing_keys)}."
            print(f"ERROR: {error_msg} Parsed Data Keys: {list(parsed_data.keys())}")
            return _llm_parse_failure(error_msg, response_text, parsed_data=parsed_data)

        # ... then their types, reporting the first violation found.
        for key, expected_type, label in _LLM_FIELD_SPECS:
            value = parsed_data[key]
            if not isinstance(value, expected_type):
                error_msg = f"Field '{key}' is not a {label} (type: {type(value).__name__})."
                print(f"ERROR: {error_msg}")
                return _llm_parse_failure(error_msg, response_text, parsed_data=parsed_data)

        if not all(isinstance(item, str) for item in parsed_data["keywords"]):
            error_msg = "Not all items in 'keywords' list are strings."
            print(f"ERROR: {error_msg}")
            return _llm_parse_failure(error_msg, response_text, parsed_data=parsed_data)

        print("INFO: Successfully parsed and validated LLM response.")
        return parsed_data

    except json.JSONDecodeError as e:
        # Parsing failed even after the trailing-comma repair.
        error_msg = f"JSONDecodeError: {e}. Issue likely near char {e.pos} in the cleaned JSON snippet."

        # Show the text around the failure position, with a '^' pointer below it.
        context_chars = 40
        start = max(0, e.pos - context_chars)
        end = min(len(json_str_cleaned), e.pos + context_chars)
        snippet = json_str_cleaned[start:end]
        if start > 0:
            snippet = "..." + snippet
        if end < len(json_str_cleaned):
            snippet = snippet + "..."
        # Shift the pointer by 3 when a leading ellipsis was added.
        pointer = " " * (min(e.pos, context_chars) + (3 if start > 0 else 0)) + "^"
        error_context = f"\nProblematic Snippet (approx. char {e.pos}):\n{snippet}\n{pointer}"

        print(f"ERROR: {error_msg}{error_context}")
        return _llm_parse_failure(error_msg + error_context, response_text)
    except Exception as general_e:
        # Boundary catch-all: callers rely on getting an error dict, never an exception.
        error_msg = f"Unexpected error during parsing/validation: {str(general_e)}"
        print(f"ERROR: {error_msg}")
        import traceback
        traceback.print_exc()
        return _llm_parse_failure(error_msg, response_text)


def generate_topic_subtopic_specifications(topic_network, Z=10):
    """
    Build Z topic/subtopic specifications.

    For every specification a topic mix is sampled with sample_concept_mix and
    one subtopic is then drawn uniformly at random for each topic in the mix.

    Args:
        topic_network: Topic graph passed through to sample_concept_mix.
        Z: Number of specifications to generate.

    Returns:
        A list of dictionaries, each containing:
        - id: 1-based sequential ID of the specification
        - topic_mix: Dictionary mapping topic IDs to weights
        - selected_subtopics: Dictionary mapping topic IDs to selected subtopics
    """
    specs = []

    for spec_id in range(1, Z + 1):
        mix = sample_concept_mix(topic_network)

        # One random subtopic per topic, drawn in the order of the mix.
        picks = {
            tid: np.random.choice(TOPICS[tid]["subtopics"])
            for tid in mix
        }

        specs.append({
            "id": spec_id,
            "topic_mix": mix,
            "selected_subtopics": picks,
        })

    return specs


def generate_diversity_parameter_variations(topic_subtopic_specs, N=30):
    """
    For each topic/subtopic specification, generate N variations of diversity parameters.

    The sampled flags and labels are converted to native Python ``bool`` / ``str``
    values: ``np.random.choice`` returns NumPy scalars, and ``np.bool_`` cannot be
    serialized by the ``json`` module nor compared with ``is True`` / ``is False``.

    Args:
        topic_subtopic_specs: Iterable of specification dictionaries with the
            keys 'id', 'topic_mix' and 'selected_subtopics'.
        N: Number of variations to generate per specification.

    Returns:
        A list of dictionaries, each containing:
        - id: Sequential ID for this abstract (starting at 1)
        - topic_spec_id: ID of the topic/subtopic specification
        - topic_mix: Dictionary mapping topic IDs to weights
        - selected_subtopics: Dictionary mapping topic IDs to selected subtopics
        - diversity_type: "no_diversity", "partial_diversity", or "full_diversity"
        - abstract_length: "short" or "long"
        - allow_topic_mention: Boolean
        - allow_subtopic_mention: Boolean
        - diversity_params: Result of sample_diversity_params for diversity_type
    """
    variations = []
    id_counter = 1

    for spec in topic_subtopic_specs:
        for _ in range(N):
            # The draw order below is part of the seeded random stream; keep it.
            diversity_type = str(
                np.random.choice(["no_diversity", "partial_diversity", "full_diversity"])
            )
            abstract_length = str(np.random.choice(["short", "long"]))
            allow_topic_mention = bool(np.random.choice([True, False]))
            allow_subtopic_mention = bool(np.random.choice([True, False]))
            diversity_params = sample_diversity_params(diversity_type)

            variations.append({
                "id": id_counter,
                "topic_spec_id": spec["id"],
                "topic_mix": spec["topic_mix"],
                "selected_subtopics": spec["selected_subtopics"],
                "diversity_type": diversity_type,
                "abstract_length": abstract_length,
                "allow_topic_mention": allow_topic_mention,
                "allow_subtopic_mention": allow_subtopic_mention,
                "diversity_params": diversity_params,
            })
            id_counter += 1

    return variations


# Sample diversity parameters according to the experiment's diversity type
def sample_diversity_params(experiment_type):
    """Draw the diversity parameters for one abstract.

    Args:
        experiment_type: "no_diversity", "partial_diversity", or any other
            value, which is treated as full diversity.

    Returns:
        None for "no_diversity". Otherwise a dictionary mapping parameter names
        to a randomly chosen option: for "partial_diversity" a random half
        (rounded down) of the DIVERSITY_PARAMS keys, for full diversity all of
        them. In both cases a random "domain" from DOMAINS is added.
    """
    if experiment_type == "no_diversity":
        # No diversity parameters at all.
        return None

    if experiment_type == "partial_diversity":
        names = list(DIVERSITY_PARAMS.keys())
        # Pick ~half of the dimensions without repetition, then one option each.
        chosen = np.random.choice(names, size=len(names)//2, replace=False)
        sampled = {name: np.random.choice(DIVERSITY_PARAMS[name]) for name in chosen}
    else:
        # Full diversity: one option for every dimension.
        sampled = {
            name: np.random.choice(options)
            for name, options in DIVERSITY_PARAMS.items()
        }

    # The domain is always included, whatever subset was sampled.
    sampled["domain"] = np.random.choice(DOMAINS)
    return sampled



    
def _post_groq_chat(prompt, model_name, model_config, config):
    """Send one chat-completion request to Groq and return the HTTP response.

    Raises:
        ValueError: If ``config["groq"]["api_key"]`` is missing or empty.
    """
    groq_settings = config.get("groq", {})
    api_key = groq_settings.get("api_key")
    if not api_key:
        raise ValueError("Groq API key not found in config")

    return requests.post(
        groq_settings.get("api_url", "https://api.groq.com/openai/v1/chat/completions"),
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": model_config.get("temperature", 0.7),
            "max_tokens": model_config.get("max_tokens", 8000),  # Model-specific limit
        },
        # Without a timeout a stalled connection would block the whole run.
        timeout=config.get("request_timeout", 120),
    )


def generate_individual_responses(variations, config):
    """
    Process each prompt individually using API calls, storing results
    with provider-prefixed keys only.

    Every variation that has a 'prompt' is sent to every model listed in
    ``config["models"]``. Failures (HTTP errors, timeouts, parsing problems,
    unsupported providers) are recorded under ``"<provider>_error"`` instead of
    being raised, so one bad request never aborts the run.

    Args:
        variations: List of variation dictionaries, each including a 'prompt'.
            Variations without a prompt are skipped with a warning.
        config: Dictionary with:
            - "models": list of dicts with "name", "provider" and optionally
              "temperature" (default 0.7) and "max_tokens" (default 8000)
            - "groq": dict with "api_key" and optionally "api_url"
            - "request_delay": seconds to sleep after each request (default 0.5)
            - "request_timeout": seconds before an HTTP request is abandoned
              (default 120)

    Returns:
        List with one shallow copy of the variation per (variation, model)
        pair, updated with 'model', 'provider', the provider-prefixed results
        and, when parsing succeeded, the output of validate_response.
    """
    import traceback

    all_results = []
    total = len(variations)

    print(f"\nProcessing {total} variations individually...")

    for i, variation_original in enumerate(variations):
        prompt = variation_original.get("prompt", "Missing prompt")
        if prompt == "Missing prompt":
            print(f"Warning: Variation {i+1} missing prompt. Skipping.")
            continue

        for model_config in config.get("models", []):
            model_name = model_config.get("name", "unknown_model")
            provider = model_config.get("provider", "unknown_provider")

            print(f"Processing variation {i+1}/{total} with model {model_name} ({provider})...")

            # Shallow copy so results of different models do not overwrite each other.
            model_variation = variation_original.copy()
            model_variation["model"] = model_name
            model_variation["provider"] = provider

            try:
                if provider == "groq":
                    response = _post_groq_chat(prompt, model_name, model_config, config)

                    if response.status_code != 200:
                        error_text = response.text[:500]  # Limit error text length
                        model_variation[f"{provider}_error"] = f"API Error: {response.status_code} - {error_text}"
                        model_variation["raw_response"] = response.text  # Full raw response on error
                    else:
                        response_data = response.json()
                        # Tolerate an empty/null 'choices' or 'message'; the parser
                        # then reports an empty response instead of this code crashing.
                        choices = response_data.get("choices") or [{}]
                        message = choices[0].get("message") or {}
                        content = message.get("content", "")

                        parsed_data = parse_llm_response(content)

                        # Store results with provider-prefixed keys ONLY
                        model_variation[f"{provider}_title"] = parsed_data.get("title")
                        model_variation[f"{provider}_abstract"] = parsed_data.get("abstract")
                        model_variation[f"{provider}_keywords"] = parsed_data.get("keywords")
                        # Parsing error (None when parsing succeeded)
                        model_variation[f"{provider}_error"] = parsed_data.get("error")
                        # Keep only lightweight metadata, not the full content again
                        model_variation["raw_response_metadata"] = {
                            "usage": response_data.get("usage"),
                            "id": response_data.get("id"),
                            "model": response_data.get("model"),
                        }

                        if not model_variation[f"{provider}_error"]:
                            # Validate only when parsing succeeded.
                            validation_results = validate_response(model_variation, provider)
                            model_variation.update(validation_results)
                        else:
                            # Flag that validation was skipped due to the parsing error.
                            model_variation[f"{provider}_validation_skipped"] = True

                # Further providers can be added here as additional elif branches.
                else:
                    model_variation[f"{provider}_error"] = f"Unsupported provider: {provider}"

            except Exception as e:
                # Boundary catch-all: record the failure on the result and keep going.
                print(f"ERROR: Exception during processing for variation {i+1}, model {model_name}: {e}")
                traceback.print_exc()
                model_variation[f"{provider}_error"] = f"Processing error: {str(e)}"

            # Add the result (including potential errors) to the list
            all_results.append(model_variation)

            # Configurable pause between requests to stay below rate limits
            time.sleep(config.get("request_delay", 0.5))

    return all_results


def prepare_batch_requests(prompts, config, model_config, batch_file_path):
    """
    Write a JSONL batch request file for the Groq API for one model.

    Each prompt becomes one line: a JSON object describing a POST to
    /v1/chat/completions, with a 'custom_id' of the form
    "<model name>-abstract-<1-based index>" that is used to map results back.

    Args:
        prompts (list): List of prompt strings.
        config (dict): General configuration dictionary (not read here).
        model_config (dict): Configuration for the specific model; must contain
            "name", may contain "temperature" (default 0.7) and "max_tokens"
            (default 8000).
        batch_file_path (str): The full path to write the batch request file to.

    Returns:
        str: ``batch_file_path``, for confirmation.

    Raises:
        KeyError: If ``model_config`` has no "name".
        IOError: If the file cannot be written (re-raised after printing).
    """
    model_name = model_config["name"]

    # Serialize every request up front; one JSON document per output line.
    jsonl_lines = [
        json.dumps({
            "custom_id": f"{model_name}-abstract-{position}",  # Used to map results back
            "method": "POST",
            "url": "/v1/chat/completions",  # Standard endpoint for chat
            "body": {
                "model": model_name,
                "messages": [{"role": "user", "content": text}],
                "temperature": model_config.get("temperature", 0.7),
                "max_tokens": model_config.get("max_tokens", 8000),
            },
        })
        for position, text in enumerate(prompts, start=1)
    ]

    try:
        with open(batch_file_path, 'w', encoding='utf-8') as out_file:
            out_file.writelines(f"{line}\n" for line in jsonl_lines)
        print(f"Batch request file prepared: {batch_file_path}")
    except IOError as e:
        print(f"ERROR: Failed to write batch request file {batch_file_path}: {e}")
        raise  # Let the caller know that writing failed

    return batch_file_path

    
def _tag_variations_with_groq_error(variations, error_message, extra_fields=None):
    """Return shallow copies of ``variations`` with ``groq_error`` (and any extra fields) set."""
    tagged = []
    for variation in variations:
        tagged_copy = variation.copy()
        if extra_fields:
            tagged_copy.update(extra_fields)
        tagged_copy["groq_error"] = error_message
        tagged.append(tagged_copy)
    return tagged


def _index_from_custom_id(custom_id):
    """Return the 0-based index encoded as the trailing ``-<n>`` of ``custom_id``, or None."""
    match = re.search(r'-(\d+)$', custom_id)
    return int(match.group(1)) - 1 if match else None


def _poll_groq_batch(groq_api_key, batch_id, config):
    """Poll a batch job until it reaches a terminal state or the attempt budget runs out.

    Returns:
        Tuple ``(batch_complete, output_file_id, batch_error)``. ``batch_error``
        is None only when the job reported the status ``completed``.
    """
    poll_interval = config.get("batch_poll_interval", 60)  # seconds, configurable
    max_poll_attempts = config.get("batch_max_poll_attempts", 60)  # e.g., 60 attempts * 60s = 1 hour

    attempt = 0
    while attempt < max_poll_attempts:
        attempt += 1
        try:
            batch_status_response = check_batch_status(groq_api_key, batch_id)
            status = batch_status_response["status"]
            print(f"Batch status: {status} (Attempt {attempt}/{max_poll_attempts})")

            if status == "completed":
                output_file_id = batch_status_response.get("output_file_id")
                error_file_id = batch_status_response.get("error_file_id")
                request_counts = batch_status_response.get("request_counts", {})
                completed_count = request_counts.get('completed', 0)
                failed_count = request_counts.get('failed', 0)
                print(f"Batch completed! Output File ID: {output_file_id}, Error File ID: {error_file_id}")
                print(f"Request Counts: {completed_count} succeeded, {failed_count} failed.")
                if error_file_id:
                    print("WARNING: Batch job reported an error file.")
                return True, output_file_id, None

            if status in ["failed", "expired", "cancelling", "cancelled"]:
                # Returning here keeps the real failure reason even when it is
                # observed on the very last permitted attempt.
                batch_error = f"Batch job {status}. Errors: {batch_status_response.get('errors')}"
                print(f"ERROR: {batch_error}")
                return False, None, batch_error

            # validating, in_progress, finalizing
            print(f"Waiting {poll_interval} seconds before checking again...")
            time.sleep(poll_interval)

        except Exception as poll_e:
            print(f"ERROR: Exception during batch status polling: {poll_e}")
            time.sleep(poll_interval * 2)  # Longer wait after error

    batch_error = f"Batch job did not complete after {max_poll_attempts} attempts."
    print(f"ERROR: {batch_error}")
    return False, None, batch_error


def _build_groq_result(variation, result_line, original_idx, model_name, provider):
    """Turn one decoded line of the batch output file into an updated variation copy."""
    custom_id = result_line.get("custom_id")
    # "response" can be present but null for failed requests, so a plain
    # .get(key, {}) default is not enough to keep the chained lookups safe.
    response = result_line.get("response") or {}
    response_body = response.get("body") or {}
    error_body = result_line.get("error")

    variation_copy = variation.copy()
    variation_copy["model"] = model_name
    variation_copy["provider"] = provider
    variation_copy["original_index"] = original_idx  # Keep track
    variation_copy["raw_response"] = result_line  # Raw line kept for reference

    if error_body:
        # Error reported for this specific request in the batch
        print(f"Warning: Error reported for custom_id {custom_id}: {error_body}")
        if isinstance(error_body, dict):
            error_message = error_body.get('message', str(error_body))
        else:
            error_message = str(error_body)
        variation_copy[f"{provider}_error"] = f"Batch API Error: {error_message}"
        return variation_copy

    choices = response_body.get("choices") or [{}]
    chat_message = (choices[0] or {}).get("message") or {}
    content = chat_message.get("content") or ""
    parsed_data = parse_llm_response(content)

    # Store results with provider-prefixed keys ONLY
    variation_copy[f"{provider}_title"] = parsed_data.get("title")
    variation_copy[f"{provider}_abstract"] = parsed_data.get("abstract")
    variation_copy[f"{provider}_keywords"] = parsed_data.get("keywords")
    variation_copy[f"{provider}_error"] = parsed_data.get("error")  # Parsing error
    variation_copy["raw_response_metadata"] = {
        "usage": response_body.get("usage"),
        "id": response_body.get("id"),
        "model": response_body.get("model"),
        "status_code": response.get("status_code")
    }

    # Apply validation only if parsing succeeded
    if not variation_copy[f"{provider}_error"]:
        variation_copy.update(validate_response(variation_copy, provider))
    else:
        variation_copy[f"{provider}_validation_skipped"] = True

    return variation_copy


def process_groq_batch(variations, config, model_config):
    """
    Process a batch of prompts using Groq's batch API for a specific model,
    storing results with provider-prefixed keys only.

    The function writes a JSONL request file, uploads it, creates a batch job,
    polls it until it finishes, downloads the output file and merges every
    result line back into a copy of the variation it belongs to (matched via
    the numeric suffix of ``custom_id``). The input dictionaries are never
    mutated. The request and result files are left on disk.

    Args:
        variations: List of variation dictionaries for this batch; each is
            expected to carry its text under ``"prompt"``.
        config: Dictionary containing API keys (``config["groq"]["api_key"]``)
            and the optional ``batch_poll_interval`` / ``batch_max_poll_attempts``.
        model_config: Dictionary for the specific model being processed.

    Returns:
        List of variation dictionaries from the batch, updated with provider-prefixed
        results and validation outcomes. Every variation that did not yield a
        usable result is returned as a copy carrying ``groq_error``.
    """
    groq_api_key = config.get("groq", {}).get("api_key")
    if not groq_api_key:
        print("ERROR: Groq API key not found in config for batch processing.")
        return _tag_variations_with_groq_error(variations, "Configuration error: Missing API Key")

    model_name = model_config.get("name", "unknown_model")
    provider = "groq"  # Hardcoded for this function
    # Model names may contain path separators (e.g. "org/model"); keep the
    # local file names flat and portable.
    safe_model_name = re.sub(r'[^A-Za-z0-9._-]+', '_', str(model_name))

    prompts = [variation.get("prompt", "Missing prompt") for variation in variations]

    # Unique filename to avoid conflicts if run in parallel
    batch_file_path = f"batch_requests_{safe_model_name}_{time.time_ns()}.jsonl"
    try:
        batch_file_path = prepare_batch_requests(prompts, config, model_config, batch_file_path)
    except Exception as e:
        print(f"ERROR: Failed to prepare batch request file: {e}")
        return _tag_variations_with_groq_error(variations, f"Batch prep error: {e}")

    print(f"\nProcessing batch for model {model_name}...")
    batch_results_list = []
    batch_error = None
    processed_indices = set()  # indices that produced an entry in batch_results_list
    line_errors = {}           # index -> reason its result line could not be processed

    try:
        print(f"Uploading batch file {batch_file_path}...")
        file_upload_response = upload_file_to_groq(groq_api_key, batch_file_path)
        input_file_id = file_upload_response["id"]
        print(f"Uploaded file ID: {input_file_id}")

        print(f"Creating batch job...")
        batch_job_response = create_batch_job(groq_api_key, input_file_id, completion_window="72h")
        batch_id = batch_job_response["id"]
        print(f"Batch job ID: {batch_id}")

        print(f"Polling for batch completion (ID: {batch_id})...")
        batch_complete, output_file_id, batch_error = _poll_groq_batch(groq_api_key, batch_id, config)

        if batch_complete and output_file_id:
            results_local_file = f"batch_results_{safe_model_name}_{batch_id}.jsonl"
            try:
                print(f"Downloading results file {output_file_id} to {results_local_file}...")
                results_local_file = download_batch_results(groq_api_key, output_file_id, results_local_file)

                with open(results_local_file, 'r', encoding='utf-8') as f:
                    for line in f:
                        if not line.strip():
                            continue  # tolerate blank lines in the JSONL file
                        custom_id = None
                        original_idx = None
                        try:
                            result_line = json.loads(line)
                            custom_id = result_line.get("custom_id")  # e.g., "llama-3.1-8b-instant-abstract-1"
                            if not custom_id:
                                print(f"Warning: Skipping result line with missing custom_id: {line[:100]}...")
                                continue

                            parsed_idx = _index_from_custom_id(custom_id)
                            if parsed_idx is None:
                                print(f"Warning: Could not parse index from custom_id '{custom_id}'. Skipping line.")
                                continue
                            if not 0 <= parsed_idx < len(variations):
                                print(f"Warning: Index {parsed_idx} from custom_id '{custom_id}' is out of range for variations list (size {len(variations)}).")
                                continue

                            original_idx = parsed_idx
                            batch_results_list.append(
                                _build_groq_result(variations[original_idx], result_line,
                                                   original_idx, model_name, provider)
                            )
                            # Only mark the index once a result was actually
                            # stored, so a failure above is still reported below.
                            processed_indices.add(original_idx)

                        except json.JSONDecodeError as json_e:
                            print(f"Warning: Failed to decode JSON from result line: {json_e}. Line: {line[:100]}...")
                        except Exception as line_proc_e:
                            print(f"ERROR: Unexpected error processing result line for custom_id {custom_id}: {line_proc_e}")
                            if original_idx is not None:
                                line_errors.setdefault(original_idx, f"Result processing error: {line_proc_e}")

            except Exception as download_e:
                print(f"ERROR: Failed to download or process batch results file: {download_e}")
                batch_error = f"Results processing error: {download_e}"

        elif batch_complete:
            # Completed but without an output file (e.g., all requests failed)
            batch_error = "Batch completed but no output file ID was generated (likely all requests failed)."
            print(f"ERROR: {batch_error}")

    except Exception as batch_proc_e:
        print(f"ERROR: Major exception during batch processing pipeline for model {model_name}: {batch_proc_e}")
        import traceback
        traceback.print_exc()
        batch_error = f"Batch pipeline error: {batch_proc_e}"

    # Every variation without a stored result gets exactly one error entry,
    # whether the whole batch failed or only its own line did.
    missing_indices = [i for i in range(len(variations)) if i not in processed_indices]
    if batch_error and missing_indices:
        print(f"Marking {len(missing_indices)} unprocessed variation(s) in this batch with error: {batch_error}")
    for i in missing_indices:
        if i in line_errors:
            reason = line_errors[i]
        elif batch_error:
            reason = batch_error
        else:
            print(f"Warning: Variation index {i} was not found in the batch results file.")
            reason = "No response found in batch results file"
        error_variation = variations[i].copy()
        error_variation["model"] = model_name
        error_variation["provider"] = provider
        error_variation[f"{provider}_error"] = reason
        batch_results_list.append(error_variation)

    return batch_results_list




def validate_forbidden_words(variation, combined_text, provider_prefix):
    """Report which forbidden topic / subtopic names occur in the generated text.

    Names are lower-cased before a plain substring test against
    ``combined_text``, so callers should pass text that is already lower-cased.
    Result keys for topics (or subtopics) are only produced when the variation
    disallows mentioning them.

    Args:
        variation: Variation dict providing ``allow_topic_mention``,
            ``allow_subtopic_mention``, ``topic_mix`` and ``selected_subtopics``.
        combined_text: Text to search (title and abstract joined by the caller).
        provider_prefix: Prefix used for every key in the returned dict.

    Returns:
        Dictionary of validation results
    """
    report = {}

    # Topic names are looked up in the module-level TOPICS table.
    if not variation["allow_topic_mention"]:
        lowered_topic_names = (
            TOPICS[topic_key]["name"].lower() for topic_key in variation["topic_mix"]
        )
        topic_hits = [name for name in lowered_topic_names if name in combined_text]
        report[f"{provider_prefix}_contains_forbidden_topics"] = bool(topic_hits)
        report[f"{provider_prefix}_found_forbidden_topics"] = topic_hits

    # Subtopics come straight from the variation itself.
    if not variation["allow_subtopic_mention"]:
        lowered_subtopics = (
            chosen.lower() for chosen in variation["selected_subtopics"].values()
        )
        subtopic_hits = [name for name in lowered_subtopics if name in combined_text]
        report[f"{provider_prefix}_contains_forbidden_subtopics"] = bool(subtopic_hits)
        report[f"{provider_prefix}_found_forbidden_subtopics"] = subtopic_hits

    return report


        



def upload_file_to_groq(api_key, file_path):
    """Upload a file to the Groq files endpoint with purpose ``batch``.

    Args:
        api_key: Groq API key sent as a bearer token.
        file_path: Path of the local file to upload; it is always sent under
            the name ``batch_file.jsonl``.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: If the endpoint answers with a status code other than 200.
    """
    url = "https://api.groq.com/openai/v1/files"
    headers = {"Authorization": f"Bearer {api_key}"}
    data = {"purpose": "batch"}

    # Open the file in a context manager so the handle is released even when
    # the request raises; previously it was never closed.
    with open(file_path, "rb") as batch_file:
        files = {"file": ("batch_file.jsonl", batch_file)}
        response = requests.post(url, headers=headers, files=files, data=data)

    if response.status_code != 200:
        raise Exception(f"File upload failed: {response.text}")

    return response.json()


def create_batch_job(api_key, input_file_id, completion_window="72h"):
    """Start a Groq batch job for a previously uploaded request file.

    Args:
        api_key: Groq API key sent as a bearer token.
        input_file_id: ID of the uploaded request file.
        completion_window: Completion window forwarded to the API as-is.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: If the endpoint answers with a status code other than 200.
    """
    job_spec = {
        "input_file_id": input_file_id,
        "endpoint": "/v1/chat/completions",
        "completion_window": completion_window
    }

    resp = requests.post(
        "https://api.groq.com/openai/v1/batches",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        },
        json=job_spec,
    )

    if resp.status_code == 200:
        return resp.json()
    raise Exception(f"Batch job creation failed: {resp.text}")


def check_batch_status(api_key, batch_id):
    """Fetch the current description of a Groq batch job.

    Args:
        api_key: Groq API key sent as a bearer token.
        batch_id: ID of the batch job to look up.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: If the endpoint answers with a status code other than 200.
    """
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    resp = requests.get(
        f"https://api.groq.com/openai/v1/batches/{batch_id}",
        headers=request_headers,
    )

    if resp.status_code == 200:
        return resp.json()
    raise Exception(f"Batch status check failed: {resp.text}")


def download_batch_results(api_key, file_id, output_file):
    """Download the content of a Groq file and store it locally.

    Args:
        api_key: Groq API key sent as a bearer token.
        file_id: ID of the remote file whose content is fetched.
        output_file: Local path the raw response bytes are written to.

    Returns:
        ``output_file``, unchanged.

    Raises:
        Exception: If the endpoint answers with a status code other than 200.
    """
    resp = requests.get(
        f"https://api.groq.com/openai/v1/files/{file_id}/content",
        headers={"Authorization": f"Bearer {api_key}"},
    )

    if resp.status_code != 200:
        raise Exception(f"Download failed: {resp.text}")

    # Persist the payload exactly as received (binary mode, no decoding).
    with open(output_file, 'wb') as sink:
        sink.write(resp.content)

    print(f"File downloaded successfully to {output_file}")
    return output_file


def create_prompt(variation):
    """Create the generation prompt for one complete parameter variation.

    Args:
        variation: Dict with the keys ``topic_mix`` (topic id -> weight),
            ``selected_subtopics`` (topic id -> subtopic name),
            ``diversity_type``, ``abstract_length`` (``"short"`` selects
            250-350 words, anything else 450-550), ``allow_topic_mention``,
            ``allow_subtopic_mention`` and ``diversity_params``.

    Returns:
        The prompt text as a single string.

    Raises:
        KeyError: If a required key is missing from ``variation`` or a topic
            id is not present in ``TOPICS`` / ``selected_subtopics``.
    """
    # Extract parameters from the variation
    topic_mix = variation["topic_mix"]
    selected_subtopics = variation["selected_subtopics"]
    diversity_type = variation["diversity_type"]
    abstract_length = variation["abstract_length"]
    allow_topic_mention = variation["allow_topic_mention"]
    allow_subtopic_mention = variation["allow_subtopic_mention"]
    diversity_params = variation["diversity_params"]

    # Prepare topic information
    formatted_topics = []
    forbidden_words = []
    forbidden_subtopic_words = []

    for topic_id, weight in topic_mix.items():
        topic_name = TOPICS[topic_id]["name"]

        # Add topic words to forbidden list only if not allowed
        if not allow_topic_mention:
            forbidden_words.append(topic_name.lower())
            forbidden_words.extend([word.lower() for word in topic_name.split()])

        # Get the subtopic
        subtopic = selected_subtopics[topic_id]

        # Add subtopic to forbidden list only if not allowed
        if not allow_subtopic_mention:
            forbidden_subtopic_words.append(subtopic.lower())
            forbidden_subtopic_words.extend([word.lower() for word in subtopic.split()])

        # Round rather than truncate: binary floats make e.g. 0.29 * 100
        # evaluate to 28.999999999999996, which int() would turn into 28.
        percentage = round(weight * 100)
        formatted_topics.append(f"{topic_name} (specifically {subtopic}): {percentage}% focus")

    # Remove duplicates while keeping first-seen order. A set would order the
    # words differently from run to run (string hashing is randomized), which
    # makes the generated prompts non-reproducible.
    forbidden_words = list(dict.fromkeys(forbidden_words))
    forbidden_subtopic_words = list(dict.fromkeys(forbidden_subtopic_words))

    # Format for prompt
    topics_text = "\n".join([f"- {t}" for t in formatted_topics])
    forbidden_words_str = ", ".join([f'"{w}"' for w in forbidden_words]) if forbidden_words else "None - you may use topic terms"
    forbidden_subtopic_words_str = ", ".join([f'"{w}"' for w in forbidden_subtopic_words]) if forbidden_subtopic_words else "None - you may use subtopic terms"

    # Determine core topic description
    num_topics = len(topic_mix)
    core_topic_names = [TOPICS[t]['name'] for t in topic_mix]
    focus_description = ' AND '.join(core_topic_names) if num_topics <= 2 else 'these topics (' + ', '.join(core_topic_names) + ')'

    # Set abstract length requirement
    min_words = 250 if abstract_length == "short" else 450
    max_words = 350 if abstract_length == "short" else 550
    length_requirement = f"MINIMUM {min_words} words and MAXIMUM {max_words} words"

    # Build diversity sections based on diversity type and parameters
    diversity_content = ""
    additional_attributes = ""
    linguistics_char = ""
    forbidden_domain = ""

    if diversity_type != "no_diversity" and diversity_params:
        if "domain" in diversity_params:
            diversity_content += f"- Domain application: {diversity_params['domain']}\n"
            forbidden_domain = f"- DOMAIN: Do not use the word \"{diversity_params['domain']}\" explicitly in the abstract or title\n"

        if "methodological_approaches" in diversity_params:
            diversity_content += f"- Methodology: {diversity_params['methodological_approaches']}\n"

        if "rhetorical_structures" in diversity_params:
            diversity_content += f"- Rhetorical style: {diversity_params['rhetorical_structures']}.\n"

        # Build additional attributes section if those params are present
        additional_attrs = []
        if "concept_granularity" in diversity_params:
            additional_attrs.append(f"- Concept granularity: {diversity_params['concept_granularity']} (Reflects in the level of detail in findings)")
        if "interdisciplinary_orientation" in diversity_params:
            additional_attrs.append(f"- Interdisciplinary orientation: {diversity_params['interdisciplinary_orientation']} (Reflected if multiple topics are distinct)")
        if "temporal_context" in diversity_params:
            additional_attrs.append(f"- Temporal context: {diversity_params['temporal_context']} (Use appropriate tense/phrasing)")

        if additional_attrs:
            additional_attributes = "ADDITIONAL PAPER ATTRIBUTES TO REFLECT:\n" + "\n".join(additional_attrs) + "\n\n"

        # Build linguistic characteristics section if those params are present
        linguistic_attrs = []
        if "terminology_density" in diversity_params:
            linguistic_attrs.append(f"- Terminology density: {diversity_params['terminology_density']}")
        if "concept_blending" in diversity_params:
            linguistic_attrs.append(f"- Concept blending approach: {diversity_params['concept_blending']}")

        if linguistic_attrs:
            linguistics_char = "LINGUISTIC CHARACTERISTICS TO EMBODY:\n" + "\n".join(linguistic_attrs) + "\n\n"

    # Create the prompt
    prompt = f"""
You are an academic expert simulating the creation of a research abstract. 
Your task is to generate ONE research abstract that fits a specific profile.

**CRITICAL REQUIREMENT: The generated 'abstract' field's text MUST be {length_requirement} long.** Do not generate summaries outside this range.

Your paper synthesizes the following topics. Adhere strictly to this distribution:
{topics_text}

VOCABULARY RESTRICTIONS:
- FORBIDDEN TOPIC WORDS: {forbidden_words_str}
- FORBIDDEN SUBTOPIC WORDS: {forbidden_subtopic_words_str}
{forbidden_domain}

REQUIRED ABSTRACT CONTENT GUIDELINES: 
- The study's focus: {focus_description}
{diversity_content}


{additional_attributes}{linguistics_char}MANDATORY INSTRUCTIONS:
1. **Generate ONE academic abstract where the 'abstract' text is {length_requirement}.**
2. Do NOT use any FORBIDDEN TOPIC WORDS in your abstract or title.
3. Do NOT use any FORBIDDEN SUBTOPIC WORDS in your abstract or title. 
4. Strictly follow the content guidelines.
5. Adhere to the Topic Distribution percentages.
6. Include at least 3-5 specific, concrete findings, methods, or implications. AVOID VAGUENESS. Elaborate on points.
7. Ensure the abstract is academically plausible and internally consistent.
8. Do NOT mention the percentages, parameters, instructions, or section headers explicitly in the output abstract text.
9. Do NOT write a short summary; generate a detailed, well-developed abstract fulfilling the word count requirements.
10. **The final 'abstract' field content MUST meet the {length_requirement} requirement.**

OUTPUT FORMAT:  
Return ONLY a single, valid JSON object containing the keys 'title', 'abstract', and 'keywords'.
- The 'title' value must be a descriptive academic title of 10-20 words
- The 'abstract' value must be a single string containing the full abstract text ({length_requirement}).
- The 'keywords' value must be a list of 4-6 relevant strings.
- Do NOT include ```json markdown wrappers, comments, explanations, or any text outside the JSON structure.

Example JSON structure (fill with generated content):
{{
  "title": "A Plausible and Specific Academic Title Reflecting the Content",
  "abstract": "Abstract text here... More text here... (Ensuring total abstract is within {length_requirement})",
  "keywords": ["Relevant Keyword 1", "Keyword 2", "Topic Keyword", "Method Keyword", "Domain Keyword"]
}}
"""

    return prompt


# --- Updated Validation Function ---
def validate_response(variation: dict, provider_prefix: str) -> dict:
    """
    Comprehensive validation function that checks constraints on generated content,
    reading input fields using the provider prefix.

    Args:
        variation: The dictionary containing original constraints AND the
                   provider-prefixed results (e.g., 'groq_title', 'groq_abstract').
        provider_prefix: The provider name (e.g., 'groq', 'openai') for field access.

    Returns:
        Dictionary of validation results with provider-prefixed keys. When the
        abstract is missing, empty or not a string, only
        ``{"<prefix>_validation_skipped": True}`` is returned.
    """
    validation_results = {}

    # The prefixed fields may exist but hold None (or another non-string) when
    # parsing produced no value, so guard before lower-casing.
    raw_abstract = variation.get(f"{provider_prefix}_abstract")
    raw_title = variation.get(f"{provider_prefix}_title")
    abstract_text = raw_abstract.lower() if isinstance(raw_abstract, str) else ""
    title_text = raw_title.lower() if isinstance(raw_title, str) else ""

    # Skip validation if no abstract was generated for this provider
    if not abstract_text:
        return {f"{provider_prefix}_validation_skipped": True}

    combined_text = abstract_text + " " + title_text

    # --- Check word count ---
    word_count = len(abstract_text.split())
    min_words = 250 if variation.get("abstract_length") == "short" else 450
    max_words = 350 if variation.get("abstract_length") == "short" else 550
    validation_results[f"{provider_prefix}_meets_length_requirements"] = min_words <= word_count <= max_words
    validation_results[f"{provider_prefix}_word_count"] = word_count

    # --- Check for forbidden words ---
    # Pass the combined (lower-cased) text derived from prefixed fields
    forbidden_words_results = validate_forbidden_words(variation, combined_text, provider_prefix)
    validation_results.update(forbidden_words_results)

    # --- Keyword format ---
    # Anything other than a list (including None) counts as an invalid format.
    keywords = variation.get(f"{provider_prefix}_keywords", [])
    if isinstance(keywords, list):
        validation_results[f"{provider_prefix}_keyword_count"] = len(keywords)
        validation_results[f"{provider_prefix}_keywords_valid_format"] = True
    else:
        validation_results[f"{provider_prefix}_keyword_count"] = 0
        validation_results[f"{provider_prefix}_keywords_valid_format"] = False

    return validation_results



def _format_topic_mix(mix) -> str:
    """Render a ``{topic_id: weight}`` mapping as ``"Name: NN%, Name: NN%"``.

    Topic ids missing from ``TOPICS`` are shown by their raw id. Anything that
    is not a dict yields a fixed placeholder string.
    """
    if not isinstance(mix, dict):
        return "Invalid/Missing Topic Mix"
    return ", ".join(
        f"{TOPICS.get(topic_id, {'name': topic_id})['name']}: {weight*100:.0f}%"
        for topic_id, weight in mix.items()
    )


def _count_successful_rows(frame: pd.DataFrame) -> int:
    """Count rows whose own ``<provider>_error`` cell is missing (NaN/None).

    Each row is checked against the error column of the provider named in that
    row. If that column does not exist, the row counts as failed.
    """
    def _row_succeeded(row) -> bool:
        error_col = f"{row.get('provider', 'unknown')}_error"
        return pd.isna(row.get(error_col, "Error Present"))

    return int(frame.apply(_row_succeeded, axis=1).sum())


def _print_success_rate_by(df: pd.DataFrame, column: str) -> None:
    """Print the per-value success rate for one categorical input column."""
    if column not in df.columns:
        print(f"  '{column}' column not found.")
        return

    for value in df[column].unique():
        if pd.isna(value):
            continue
        subset = df[df[column] == value]
        count = len(subset)
        if count == 0:
            print(f"  {value}: 0 total")
            continue
        success_count = _count_successful_rows(subset)
        print(f"  {value}: {count} total, {success_count} successful ({success_count/count*100:.1f}%)")


def generate_synthetic_dataset(config: dict) -> pd.DataFrame:
    """
    Generate a synthetic dataset using systematic parameter variation,
    handling results with provider-prefixed keys and printing summary statistics.

    The function builds the topic network, derives topic/subtopic specifications
    and diversity variations, attaches a prompt to every variation, runs the
    configured generation path (batch or individual calls), and assembles the
    returned records into a DataFrame. The DataFrame is written to CSV, and a
    Feather copy is attempted on a best-effort basis, before the summary is printed.

    Args:
        config: Dictionary of generation settings. Keys read here:
            ``num_topic_specs`` (default 10), ``variations_per_spec``
            (default 20), ``batch_mode`` (default False), ``models`` (list of
            dicts with ``name`` and ``provider``; used in batch mode), and
            ``output_file`` (default ``"synthetic_abstracts_output.csv"``).
            The whole dict is also forwarded to the generation helpers.

    Returns:
        A pandas DataFrame containing the generated data and metadata. An empty
        DataFrame is returned when no results were produced or the results
        could not be converted to a DataFrame.
    """
    # Function-scope import: `os` is not among the file's top-level imports.
    import os

    # Initial setup - create topic network and variations
    print("Creating topic network...")
    topic_network = create_topic_network()

    num_topic_specs = config.get('num_topic_specs', 10)
    print(f"\nGenerating {num_topic_specs} topic/subtopic specifications...")
    topic_subtopic_specs = generate_topic_subtopic_specifications(
        topic_network,
        num_topic_specs
    )

    variations_per_spec = config.get('variations_per_spec', 20)
    print(f"\nGenerating {variations_per_spec} diversity parameter variations for each specification...")
    variations = generate_diversity_parameter_variations(
        topic_subtopic_specs,
        variations_per_spec
    )

    print(f"\nPreparing {len(variations)} prompts for processing...")
    for variation in tqdm(variations, desc="Preparing prompts"):
        try:
            variation["prompt"] = create_prompt(variation)
        except Exception as e:
            # Best effort: one bad variation must not abort the whole run.
            print(f"Error creating prompt for variation ID {variation.get('id')}: {e}")
            variation["prompt"] = "Error generating prompt"  # Mark problematic variations

    # --- Generation: batch or individual calls ---
    all_results = []
    batch_mode = config.get("batch_mode", False)

    if batch_mode:
        print("\nUsing BATCH mode for generation...")
        for model_config in config.get("models", []):
            model_name = model_config.get("name", "unknown_model")
            provider = model_config.get("provider", "unknown_provider")

            print(f"\nProcessing model: {model_name} via {provider} (Batch Mode)...")

            if provider == "groq":
                batch_results = process_groq_batch(variations, config, model_config)
                all_results.extend(batch_results)
            else:
                # Only Groq has a batch implementation here; other providers are skipped.
                print(f"Warning: Batch mode not implemented or configured for provider '{provider}'. Skipping.")
    else:
        print("\nUsing NON-BATCH mode for generation...")
        # generate_individual_responses handles looping through models internally
        individual_results = generate_individual_responses(variations, config)
        all_results.extend(individual_results)

    # --- Create the dataset ---
    if not all_results:
        print("\nERROR: No results were generated. Cannot create DataFrame.")
        return pd.DataFrame()

    print("\nCreating DataFrame from results...")
    try:
        df = pd.DataFrame(all_results)
    except Exception as e:
        print(f"\nERROR: Failed to create DataFrame from results: {e}")
        print("Dumping raw results for inspection (first 5):")
        print(all_results[:5])
        return pd.DataFrame()

    # --- Post-processing ---
    print("Post-processing DataFrame...")

    # Human-readable rendering of the topic mix, when the column is present.
    if 'topic_mix' in df.columns:
        df['topic_mix_str'] = df['topic_mix'].apply(_format_topic_mix)
    else:
        df['topic_mix_str'] = "N/A"

    # --- Save the dataset ---
    output_file = config.get("output_file", "synthetic_abstracts_output.csv")
    # Swap only the final extension. A plain str.replace('.csv', ...) leaves the
    # path unchanged when there is no '.csv' suffix, so the Feather write would
    # overwrite the CSV; it also rewrites '.csv' occurring mid-path.
    feather_output_file = os.path.splitext(output_file)[0] + ".feather"
    if feather_output_file == output_file:
        # output_file itself ends in '.feather'; keep the two outputs distinct.
        feather_output_file = output_file + ".feather"

    print("\nSaving dataset...")
    try:
        df.to_csv(output_file, index=False, encoding='utf-8')
        print(f"Dataset saved to {output_file}")
        try:
            df.to_feather(feather_output_file)
            print(f"Dataset saved to {feather_output_file}")
        except Exception as fe:
            # Feather is optional; a failure here only warns.
            print(f"Warning: Could not save to Feather format: {fe}")
    except Exception as e:
        print(f"\nError saving dataset: {e}")

    # --- Print Summary Statistics ---
    print("\n" + "="*30)
    print("  Generation Summary")
    print("="*30)

    if df.empty:
        print("DataFrame is empty, cannot generate summary.")
        return df

    total_rows = len(df)
    print(f"Total rows (Variations * Models): {total_rows}")
    if 'provider' not in df.columns:
        print("ERROR: 'provider' column missing, cannot generate provider-specific stats.")
        return df

    # Count success by provider: a missing (NaN) error cell means success.
    print("\n--- Success Rate by Provider ---")
    providers = df['provider'].unique()
    for provider in providers:
        if pd.isna(provider):
            continue
        provider_df = df[df['provider'] == provider]
        provider_error_col = f'{provider}_error'

        if provider_error_col not in provider_df.columns:
            print(f"Warning: Error column '{provider_error_col}' not found for provider '{provider}'. Cannot calculate success rate.")
            continue

        success_count = provider_df[provider_error_col].isna().sum()
        total_provider = len(provider_df)
        if total_provider > 0:
            failed_count = total_provider - success_count
            print(f"{str(provider).capitalize()} Provider:")
            print(f"  Total Results: {total_provider}")
            print(f"  Successful Generations (No Error): {success_count} ({success_count/total_provider*100:.1f}%)")
            print(f"  Failed Generations (Error Present): {failed_count} ({failed_count/total_provider*100:.1f}%)")
        else:
            print(f"{str(provider).capitalize()} Provider: No results found.")

    # Statistics by input parameters (using provider-specific errors)
    print("\n--- Success Rate by Diversity Type ---")
    _print_success_rate_by(df, 'diversity_type')

    print("\n--- Success Rate by Abstract Length ---")
    _print_success_rate_by(df, 'abstract_length')

    print("\n--- Input Parameter Counts ---")
    if 'allow_topic_mention' in df.columns:
        print(f"  Topic allowed: {df['allow_topic_mention'].eq(True).sum()} variations")
        print(f"  Topic forbidden: {df['allow_topic_mention'].eq(False).sum()} variations")
    if 'allow_subtopic_mention' in df.columns:
        print(f"  Subtopic allowed: {df['allow_subtopic_mention'].eq(True).sum()} variations")
        print(f"  Subtopic forbidden: {df['allow_subtopic_mention'].eq(False).sum()} variations")

    # Constraint violations by provider
    print("\n--- Constraint Violations by Provider (Among Successful Generations) ---")
    for provider in providers:
        if pd.isna(provider):
            continue
        provider_prefix = str(provider)
        error_col = f'{provider_prefix}_error'

        # Without the error column, successes cannot be identified. Skip the
        # provider rather than raise KeyError and lose the DataFrame.
        if error_col not in df.columns:
            print(f"\n{provider_prefix.capitalize()} Provider:")
            print(f"  Column '{error_col}' not found.")
            continue

        success_df = df[(df['provider'] == provider) & (df[error_col].isna())]

        total_successful = len(success_df)
        print(f"\n{provider_prefix.capitalize()} Provider (Successful: {total_successful}):")

        if total_successful == 0:
            print("  No successful generations to check for constraint violations.")
            continue

        # Check forbidden topic violations
        col_forbidden_topic = f'{provider_prefix}_contains_forbidden_topics'
        if col_forbidden_topic in success_df.columns:
            forbidden_topic_violations = success_df[col_forbidden_topic].eq(True).sum()
            print(f"  Mention forbidden topics: {forbidden_topic_violations} ({forbidden_topic_violations/total_successful*100:.1f}%)")
        else:
            print(f"  Column '{col_forbidden_topic}' not found.")

        # Check forbidden subtopic violations
        col_forbidden_subtopic = f'{provider_prefix}_contains_forbidden_subtopics'
        if col_forbidden_subtopic in success_df.columns:
            forbidden_subtopic_violations = success_df[col_forbidden_subtopic].eq(True).sum()
            print(f"  Mention forbidden subtopics: {forbidden_subtopic_violations} ({forbidden_subtopic_violations/total_successful*100:.1f}%)")
        else:
            print(f"  Column '{col_forbidden_subtopic}' not found.")

        # Check length requirement violations (rows where the flag is False)
        col_length_met = f'{provider_prefix}_meets_length_requirements'
        if col_length_met in success_df.columns:
            length_violations = success_df[col_length_met].eq(False).sum()
            print(f"  Don't meet length reqs: {length_violations} ({length_violations/total_successful*100:.1f}%)")
        else:
            print(f"  Column '{col_length_met}' not found.")

    print("\n" + "="*30)
    print("  Summary End")
    print("="*30)

    return df

# --- Updated Configuration ---

# Run settings passed to generate_synthetic_dataset() below.
CONFIG = dict(
    num_topic_specs=15,
    variations_per_spec=30,
    output_file="synthetic_abstracts_multi_model.csv",
    batch_mode=False,  # False -> individual API calls; True -> batch path

    # The two model entries differ only by name.
    models=[
        {
            "name": model_name,
            "provider": "groq",
            "max_tokens": 8000,
            "temperature": 0.7,
        }
        for model_name in ("llama-3.1-8b-instant", "llama-3.3-70b-versatile")
    ],

    # Provider-level connection settings.
    groq=dict(
        api_url="https://api.groq.com/openai/v1/chat/completions",
        api_key=groq_api_key,
    ),
)


# Run the full pipeline and keep the resulting DataFrame at module level.
results_df = generate_synthetic_dataset(CONFIG)