#!/usr/bin/env python3
"""
Qwen 3.6-Full Quality Evaluation via vLLM API.
Handles reasoning models properly (content field is populated after reasoning).
"""

import json
import sys
import os
import requests
import re

# Chat-completions endpoint of the vLLM server on localhost; every request
# in this script is POSTed here by query().
API_URL = "http://localhost:11439/v1/chat/completions"
# Model name sent in every request body; presumably must match the name the
# server was started with (vLLM --served-model-name) — confirm.
MODEL = "qwen-3.6-full"

def query(prompt, max_tokens=4096, timeout=600):
    """Send one user prompt to the vLLM server and return the response text.

    Reasoning models return their chain of thought separately from the final
    answer (``message.content``). Both are returned, content first, joined by
    a newline; whichever part is empty is omitted. Putting content first means
    a first-match search (as in extract_answer) sees the final answer before
    any intermediate numbers in the reasoning. When content is empty
    (presumably when max_tokens ran out mid-reasoning), only the reasoning
    is returned.

    Args:
        prompt: Text of the single user message.
        max_tokens: Generation limit forwarded to the server.
        timeout: Seconds passed to ``requests`` as its connect/read timeout,
            so a stalled server cannot hang the run forever. ``None``
            restores the previous wait-indefinitely behavior.

    Returns:
        The combined text, or "" if the server returned neither part.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        KeyError, IndexError: If the JSON body has no ``choices[0].message``.
    """
    payload = {
        "model": MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }
    resp = requests.post(API_URL, json=payload, timeout=timeout)
    resp.raise_for_status()
    msg = resp.json()["choices"][0]["message"]

    # ``content`` may be absent or explicitly null when no final answer was produced.
    content = (msg.get("content") or "").strip()
    # The reasoning text has been exposed under different keys across vLLM
    # versions ("reasoning" vs "reasoning_content"); accept either.
    reasoning = (msg.get("reasoning") or msg.get("reasoning_content") or "").strip()

    # Content first so the final answer precedes the reasoning; skip empty
    # parts so there is no stray leading newline.
    return "\n".join(part for part in (content, reasoning) if part)

# ── GSM8K evaluation ──────────────────────────────────────────

# Worked (question, answer) pairs used as few-shot demonstrations by
# build_gsm8k_prompt(). Every answer shows its arithmetic and ends with
# "The answer is N.", nudging the model toward a phrasing that
# extract_answer() searches for.
GSM8K_EXAMPLES = [
    ("Janet has 3 apples. She buys 5 more. How many apples does she have?",
     "Janet starts with 3 apples. She buys 5 more. So 3 + 5 = 8. The answer is 8."),
    ("There are 15 birds in a tree. 7 fly away. How many are left?",
     "There were 15 birds. 7 fly away. So 15 - 7 = 8. The answer is 8."),
    ("A baker bakes 24 cookies. He puts them into 6 boxes equally. How many cookies in each box?",
     "24 cookies divided into 6 boxes = 24 / 6 = 4 cookies per box. The answer is 4."),
    ("Tom has $50. He buys a book for $12 and a pen for $3. How much money does he have left?",
     "Tom had $50. He spent $12 + $3 = $15. So $50 - $15 = $35. The answer is 35."),
    ("A train travels at 60 miles per hour. How far does it travel in 3 hours?",
     "Distance = speed × time = 60 × 3 = 180 miles. The answer is 180."),
]

# Test items as (question, expected numeric answer). eval_gsm8k() sends each
# question inside the few-shot prompt and compares the expected value with
# the number extracted from the model's reply.
GSM8K_TEST = [
    ("John has 12 marbles. He gives 4 to his sister and loses 2. How many does he have left?", 6),
    ("A pizza has 8 slices. 3 people each eat 2 slices. How many slices are left?", 2),
    ("Sarah reads 15 pages on Monday, 22 on Tuesday, and 18 on Wednesday. How many pages total?", 55),
    ("A farmer has 36 eggs. He packs them into cartons of 6. How many cartons does he need?", 6),
    ("A store sells apples for $2 each. Lisa buys 3 apples and pays with a $10 bill. How much change?", 4),
    ("There are 48 students in a school. Each class has 24 students. How many classes are there?", 2),
    ("Mike runs 5 kilometers each day. How many kilometers does he run in 2 weeks (14 days)?", 70),
    ("A box contains 120 candies. If 8 children share them equally, how many does each get?", 15),
    ("Emma has $75. She spends $28 on a dress and $15 on shoes. How much does she have left?", 32),
    ("A garden has 6 rows of carrots with 12 carrots in each row. 15 carrots are picked. How many remain?", 57),
]

def extract_answer(text):
    """Extract the numeric answer from a model response.

    Rules are tried from most to least explicit; the first rule that matches
    decides the result:

    1. ``\\boxed{N}`` (LaTeX convention common in reasoning-model output)
    2. ``Final answer: N`` (any case)
    3. ``The answer is N`` (any case; the phrasing of the few-shot examples)
    4. ``Answer: N`` (any case)
    5. ``... = N`` ending a line -- the LAST such line, because in
       step-by-step work earlier lines hold intermediate results
    6. the last number on the last non-empty line that is not a question

    For rules 1-4 the first occurrence in *text* wins. Numbers may use comma
    thousands separators ("1,200") and may be preceded by "$" or markdown
    bold "**"; separators are stripped before conversion.

    Args:
        text: Raw response text (may be empty).

    Returns:
        The answer as a float, or None if no number could be found.
    """
    if not text:
        return None

    # Exactly one capturing group: an integer or decimal, optionally with
    # comma thousands separators. The comma form is tried first so that
    # "1,200" is not read as 1.
    num = r'(\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?)'

    def to_float(token):
        return float(token.replace(",", ""))

    # Explicit answer markers, most specific first. The filler classes allow
    # whitespace, ":", "-", "$" and "*" (markdown bold) between marker and number.
    explicit_patterns = (
        r'\\boxed\{\s*\$?\s*' + num,
        r'final answer[\s.:\-*$]*' + num,
        r'the answer is[\s:\-*$]*' + num,
        r'answer[\s:\-*$]*' + num,
    )
    for pattern in explicit_patterns:
        m = re.search(pattern, text, re.IGNORECASE)
        if m:
            return to_float(m.group(1))

    # "... = N" closing a line (optional trailing period). Take the LAST one:
    # e.g. "6 x 12 = 72" precedes "72 - 15 = 57" and only the latter is final.
    line_results = re.findall(r'=\s*\$?\s*' + num + r'\s*\.?\s*$', text, re.MULTILINE)
    if line_results:
        return to_float(line_results[-1])

    # Last resort: the final number on the last non-empty line that is not a
    # question (a restated question would contribute its own numbers).
    for line in reversed(text.splitlines()):
        line = line.strip()
        if not line or '?' in line:
            continue
        numbers = re.findall(num, line)
        if numbers:
            return to_float(numbers[-1])
    return None

def build_gsm8k_prompt(question):
    """Return a few-shot prompt: every GSM8K_EXAMPLES pair, then *question*.

    Each demonstration is rendered as ``Question: ...`` / ``Answer: ...``
    followed by a blank line. The prompt ends with an open ``Answer:`` so the
    model continues with its own worked answer.
    """
    demonstrations = "".join(
        f"Question: {example_q}\nAnswer: {example_a}\n\n"
        for example_q, example_a in GSM8K_EXAMPLES
    )
    return f"{demonstrations}Question: {question}\nAnswer:"

def eval_gsm8k():
    """Run the few-shot GSM8K-style evaluation over ``GSM8K_TEST``.

    For each item, the prompt from build_gsm8k_prompt() is sent to the model.
    The number parsed by extract_answer() is compared with the expected value
    (absolute tolerance 0.01), and a per-item line is printed. A failed
    request is reported and counted as incorrect.

    Returns:
        ``{"gsm8k": {"accuracy": <percent>, "correct": <int>, "total": <int>}}``.
        Accuracy is 0.0 when ``GSM8K_TEST`` is empty (instead of raising
        ZeroDivisionError).
    """
    total = len(GSM8K_TEST)
    print("\n" + "=" * 60)
    # Derive the count from the data so the header cannot drift from it.
    print(f"GSM8K EVALUATION ({total} samples)")
    print("=" * 60)

    correct = 0
    for i, (question, expected) in enumerate(GSM8K_TEST, start=1):
        prompt = build_gsm8k_prompt(question)
        try:
            response = query(prompt, max_tokens=1024)
        except Exception as e:
            # Boundary of a long-running eval: report the failure and keep
            # going instead of losing the remaining samples.
            print(f"\n  [{i}/{total}] ERROR: {e}")
            continue

        pred = extract_answer(response)
        matched = pred is not None and abs(pred - expected) < 0.01
        if matched:
            correct += 1

        status = "✅" if matched else "❌"
        print(f"\n  [{i}/{total}] {status} {question}")
        print(f"  Expected: {expected}, Got: {pred}")
        print(f"  Response: {response[:150]}...")

    accuracy = correct / total * 100 if total else 0.0
    print(f"\n  ✅ Accuracy: {correct}/{total} ({accuracy:.1f}%)")
    return {"gsm8k": {"accuracy": accuracy, "correct": correct, "total": total}}

# ── Simple Knowledge QA ───────────────────────────────────────

# Short factual questions for eval_knowledge(). The matching expected
# substrings live in a separate list inside eval_knowledge() and are paired
# with these by position (zip), so keep both lists in the same order.
KNOWLEDGE_QA = [
    "What is the capital of France?",
    "What is the chemical symbol for water?",
    "Who wrote the novel '1984'?",
    "What is the speed of light in vacuum? (in km/s)",
    "What year did World War II end?",
]

def _normalize_for_match(text):
    """Return *text* NFKC-normalized, case-folded and with commas removed.

    NFKC maps compatibility characters to plain ones (e.g. subscript "₂" ->
    "2", so "H₂O" matches "h2o"); dropping commas lets "299792" match
    "299,792".
    """
    import unicodedata

    return unicodedata.normalize("NFKC", text).casefold().replace(",", "")


def eval_knowledge():
    """Run the short factual QA check over ``KNOWLEDGE_QA``.

    A response counts as correct when the expected answer occurs in it as a
    substring after both are passed through _normalize_for_match(), so
    formatting differences ("H₂O" vs "h2o", "299792" vs "299,792") are not
    scored as misses. A failed request is reported and counted as incorrect.

    Returns:
        ``{"knowledge": {"accuracy": <percent>, "correct": <int>, "total": <int>}}``,
        the same shape eval_gsm8k() returns. Accuracy is 0.0 when there are
        no questions.
    """
    expected = ["paris", "h2o", "george orwell", "299,792", "1945"]
    # Questions and expected answers are paired by position; zip() stops at
    # the shorter list, so count the pairs actually evaluated.
    pairs = list(zip(KNOWLEDGE_QA, expected))
    total = len(pairs)

    print("\n" + "=" * 60)
    print(f"KNOWLEDGE QA ({total} samples)")
    print("=" * 60)

    correct = 0
    for i, (question, expect) in enumerate(pairs, start=1):
        try:
            response = query(question, max_tokens=256)
        except Exception as e:
            # Report and continue so one failed request does not abort the suite.
            print(f"\n  [{i}/{total}] ERROR: {e}")
            continue

        match = _normalize_for_match(expect) in _normalize_for_match(response)
        if match:
            correct += 1
        status = "✅" if match else "❌"
        print(f"\n  [{i}/{total}] {status} Q: {question}")
        print(f"  A: {response[:100]}")
        print(f"  Expected: {expect}")

    accuracy = correct / total * 100 if total else 0.0
    print(f"\n  ✅ Accuracy: {correct}/{total} ({accuracy:.1f}%)")
    return {"knowledge": {"accuracy": accuracy, "correct": correct, "total": total}}

# ── Reasoning Test ────────────────────────────────────────────

# (label, prompt) pairs for eval_reasoning(). The responses are only printed
# for manual inspection; no expected answers are defined or scored.
REASONING = [
    ("Logical Deduction", 
     "If all A are B, and all B are C, are all A necessarily C?"),
    ("Common Sense", 
     "A ball is thrown straight up in the air. After 5 seconds, is it still going up, coming down, or could it be either? Explain."),
]

def eval_reasoning():
    """Print the model's free-form replies to each REASONING prompt.

    Nothing is scored. The first 300 characters of each reply are shown for
    manual review, and a failure for one prompt is reported without stopping
    the others.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("REASONING TESTS")
    print(banner)

    for label, prompt in REASONING:
        try:
            answer = query(prompt, max_tokens=512)
            print(f"\n  [{label}]:")
            print(f"  Q: {prompt}")
            print(f"  A: {answer[:300]}")
        except Exception as err:
            print(f"  [{label}] ERROR: {err}")

# ── Main ──────────────────────────────────────────────────────

def main():
    """Run every evaluation against the vLLM endpoint and save the GSM8K score.

    Exits with status 1 if a tiny warm-up request to the API fails. Only the
    dict returned by eval_gsm8k() is printed in the summary and written to
    results/qwen3.6_quality.json; the other suites only print their output.
    """
    rule = "=" * 60
    print("Qwen 3.6-Full Quality Evaluation")
    print(rule)
    print(f"API: {API_URL}")
    print(f"Model: {MODEL}")
    print()

    # Fail fast when the server is unreachable instead of erroring per sample.
    print("Testing API connectivity...")
    try:
        greeting = query("Hello", max_tokens=10)
        print(f"✅ API OK (response: '{greeting}')")
    except Exception as err:
        print(f"❌ API failed: {err}")
        sys.exit(1)

    summary = eval_gsm8k()
    eval_knowledge()
    eval_reasoning()

    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    print(json.dumps(summary, indent=2))

    # Relative path: the results directory is created under the current
    # working directory.
    save_path = "results/qwen3.6_quality.json"
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, "w") as fh:
        json.dump(summary, fh, indent=2)
    print(f"\nResults saved to: {save_path}")


if __name__ == "__main__":
    main()
