import json
import hashlib
from pathlib import Path
from langchain_text_splitters import RecursiveCharacterTextSplitter

# --------------------------------------------------
# CONFIGURATION
# --------------------------------------------------

WEB_FILE = "_manual_scraped_data/manual_raw.json"
DOCX_FILE = "docx_data/docx_raw.json"
COMBINED_FILE = "combined_raw.json"

OUTPUT_DIR = Path("chunked_data")
OUTPUT_DIR.mkdir(exist_ok=True)

CHUNK_SIZE = 1000     # characters per chunk (the splitter measures length with len() by default)
CHUNK_OVERLAP = 200   # characters of overlap between consecutive chunks

# --------------------------------------------------
# LOAD + MERGE DATA
# --------------------------------------------------

with open(WEB_FILE, "r", encoding="utf-8") as f:
    web = json.load(f)

with open(DOCX_FILE, "r", encoding="utf-8") as f:
    docx = json.load(f)

# Both inputs are expected to be JSON lists of page records with at least
# "text" and "url" fields; records may also carry an explicit "source_type".
pages = web + docx

# Save combined (for traceability / debugging)
with open(COMBINED_FILE, "w", encoding="utf-8") as f:
    json.dump(pages, f, indent=2, ensure_ascii=False)

# --------------------------------------------------
# TEXT SPLITTER
# --------------------------------------------------

# RecursiveCharacterTextSplitter splits on paragraph breaks first, then
# newlines, then spaces, so chunks end on natural boundaries where possible.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

chunks = []

# --------------------------------------------------
# CHUNKING LOGIC
# --------------------------------------------------

for page in pages:
    text = page.get("text", "")
    if not text.strip():
        continue  # skip records with no usable text
    text_chunks = splitter.split_text(text)
    total = len(text_chunks)

    for idx, chunk in enumerate(text_chunks):
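        # Deterministic ID: the same URL, chunk position, and text always
        # hash to the same value, so re-running the script is idempotent.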
        chunk_id = hashlib.sha256(
            f"{page['url']}|{idx}|{chunk}".encode("utf-8")
        ).hexdigest()

        chunks.append({
            "chunk_id": chunk_id,
            "text": chunk,
            "metadata": {
                "source_url": page["url"],
                "source_type": page.get("source_type", "web"),
                "chunk_index": idx,
                "total_chunks": total
            }
        })
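
# Sanity check: duplicate IDs usually mean the same page appears in both
# source files; warn rather than silently keeping duplicate chunks.
unique_ids = {c["chunk_id"] for c in chunks}
if len(unique_ids) != len(chunks):
    print(f"Warning: {len(chunks) - len(unique_ids)} duplicate chunk IDs detected")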

# --------------------------------------------------
# SAVE OUTPUT
# --------------------------------------------------

output_file = OUTPUT_DIR / "tracwater_chunks.json"

with open(output_file, "w", encoding="utf-8") as f:
    json.dump(chunks, f, indent=2, ensure_ascii=False)

print("----------------------------------")
print(f"Total pages processed: {len(pages)}")
print(f"Total chunks created: {len(chunks)}")
print(f"Saved chunks to: {output_file}")
print("----------------------------------")

# --------------------------------------------------
# SAMPLE VERIFICATION
# --------------------------------------------------

if chunks:
    sample = chunks[0]
    print("\nSample Chunk:")
    print("Chunk ID:", sample["chunk_id"])
    print("Source URL:", sample["metadata"]["source_url"])
    print("Source Type:", sample["metadata"]["source_type"])
    print("Text (first 300 chars):\n")
    print(sample["text"][:300])
