Source code for abstracts_explorer.rag

"""
Retrieval Augmented Generation (RAG) for NeurIPS abstracts.

This module provides RAG functionality to query papers and generate
contextual responses using OpenAI-compatible language model APIs.
"""

import json
import logging
from pathlib import Path
from typing import List, Dict, Optional, Any

from openai import OpenAI

from .config import get_config
from .paper_utils import format_search_results, build_context_from_papers, PaperFormattingError
from .mcp_tools import get_mcp_tools_schema, execute_mcp_tool, format_tool_result_for_llm

logger = logging.getLogger(__name__)


class RAGError(Exception):
    """Exception raised for RAG-related errors."""

    pass


class RAGChat:
    """
    RAG chat interface for querying NeurIPS papers.

    Uses embeddings for semantic search and an OpenAI-compatible API for
    response generation. Optionally integrates with MCP clustering tools to
    answer questions about conference topics, trends, and developments.

    Parameters
    ----------
    embeddings_manager : EmbeddingsManager
        Manager for embeddings and vector search.
    database : DatabaseManager
        Database instance for querying paper details. REQUIRED.
    lm_studio_url : str, optional
        URL for OpenAI-compatible API, by default "http://localhost:1234"
    model : str, optional
        Name of the language model, by default "auto"
    max_context_papers : int, optional
        Maximum number of papers to include in context, by default 5
    temperature : float, optional
        Sampling temperature for generation, by default 0.7
    enable_mcp_tools : bool, optional
        Enable MCP clustering tools for topic analysis, by default True

    Examples
    --------
    >>> from abstracts_explorer.embeddings import EmbeddingsManager
    >>> from abstracts_explorer.database import DatabaseManager
    >>> em = EmbeddingsManager()
    >>> em.connect()
    >>> db = DatabaseManager()
    >>> db.connect()
    >>> chat = RAGChat(em, db)
    >>>
    >>> # Ask about specific papers
    >>> response = chat.query("What are the latest advances in deep learning?")
    >>> print(response)
    >>>
    >>> # Ask about conference topics (uses MCP tools automatically)
    >>> response = chat.query("What are the main research topics at NeurIPS?")
    >>> print(response)
    """

    def __init__(
        self,
        embeddings_manager,
        database,
        lm_studio_url: Optional[str] = None,
        model: Optional[str] = None,
        max_context_papers: Optional[int] = None,
        temperature: Optional[float] = None,
        enable_mcp_tools: bool = True,
    ):
        """
        Initialize RAG chat.

        Parameters are optional and will use values from environment/config
        if not provided.

        Parameters
        ----------
        embeddings_manager : EmbeddingsManager
            Manager for embeddings and vector search.
        database : DatabaseManager
            Database instance for querying paper details. REQUIRED - no fallback allowed.
        lm_studio_url : str, optional
            URL for OpenAI-compatible API. If None, uses config value.
        model : str, optional
            Name of the language model. If None, uses config value.
        max_context_papers : int, optional
            Maximum number of papers to include in context. If None, uses config value.
        temperature : float, optional
            Sampling temperature for generation. If None, uses config value.
        enable_mcp_tools : bool, optional
            Whether to enable MCP clustering tools for the LLM. Default is True.

        Raises
        ------
        RAGError
            If required parameters are missing or invalid.
        """
        if embeddings_manager is None:
            raise RAGError("embeddings_manager is required.")
        if database is None:
            raise RAGError("database is required. RAGChat cannot operate without database access.")

        config = get_config()

        self.embeddings_manager = embeddings_manager
        self.database = database
        self.lm_studio_url = (lm_studio_url or config.llm_backend_url).rstrip("/")
        self.model = model or config.chat_model
        self.max_context_papers = max_context_papers or config.max_context_papers
        # Explicit None check so a caller-supplied temperature of 0.0 is respected
        # instead of silently falling back to the configured default.
        self.temperature = temperature if temperature is not None else config.chat_temperature
        self.enable_query_rewriting = config.enable_query_rewriting
        self.query_similarity_threshold = config.query_similarity_threshold
        self.enable_mcp_tools = enable_mcp_tools
        self.conversation_history: List[Dict[str, str]] = []
        self.last_search_query: Optional[str] = None

        # OpenAI client - lazy loaded on first use to avoid API calls during test collection
        self._openai_client: Optional[OpenAI] = None
        self._llm_backend_auth_token = config.llm_backend_auth_token

    @property
    def openai_client(self) -> OpenAI:
        """
        Get the OpenAI client, creating it lazily on first access.

        This lazy loading prevents API calls during test collection.

        Returns
        -------
        OpenAI
            Initialized OpenAI client instance.
        """
        if self._openai_client is None:
            self._openai_client = OpenAI(
                base_url=f"{self.lm_studio_url}/v1",
                api_key=self._llm_backend_auth_token or "lm-studio-local",
            )
        return self._openai_client

    def _rewrite_query(self, user_query: str) -> str:
        """
        Rewrite user query into an effective search query.

        Uses the LLM to transform conversational questions into optimized
        search queries, considering conversation history for follow-up questions.

        Parameters
        ----------
        user_query : str
            Original user query or question.

        Returns
        -------
        str
            Rewritten query optimized for semantic search.

        Notes
        -----
        Falls back to returning the original query if rewriting fails.

        Examples
        --------
        >>> chat = RAGChat(em, db)
        >>> rewritten = chat._rewrite_query("What about transformers?")
        >>> print(rewritten)
        "transformer architecture attention mechanism neural networks"
        """
        # Build system prompt for query rewriting
        system_prompt = (
            "You are a query rewriting assistant. Your task is to rewrite user questions "
            "into effective search queries for finding relevant research papers. "
            "Convert conversational questions into keyword-based search queries. "
            "For follow-up questions, incorporate context from previous conversation. "
            "Output ONLY the rewritten query, nothing else. "
            "Keep queries concise (5-15 words) and focused on key concepts."
        )

        # Build messages with conversation history for context
        messages = [{"role": "system", "content": system_prompt}]

        # Include recent conversation history for context (last 4 messages)
        if self.conversation_history:
            context_history = self.conversation_history[-4:]
            messages.extend(context_history)

        # Add current query
        messages.append(
            {
                "role": "user",
                "content": f"Rewrite this query: {user_query}",
            }
        )

        try:
            response = self.openai_client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.3,  # Lower temperature for consistent rewrites
                max_tokens=100,  # Short rewritten queries
                timeout=30,  # Shorter timeout for quick rewriting
            )

            rewritten = response.choices[0].message.content.strip()
            # Remove any quotes or extra formatting
            rewritten = rewritten.strip("\"'")

            logger.info(f"Rewrote query: '{user_query}' -> '{rewritten}'")
            return rewritten

        except Exception as e:
            logger.warning(f"Query rewriting failed: {e}, using original query")
            return user_query

    def _should_retrieve_papers(self, rewritten_query: str) -> bool:
        """
        Determine if new papers should be retrieved based on query similarity.

        Compares the rewritten query with the last search query to avoid
        redundant retrievals for similar follow-up questions.

        Parameters
        ----------
        rewritten_query : str
            The rewritten search query.

        Returns
        -------
        bool
            True if papers should be retrieved, False to reuse previous context.

        Examples
        --------
        >>> chat = RAGChat(em, db)
        >>> chat._should_retrieve_papers("deep learning networks")
        True
        """
        # Always retrieve if this is the first query
        if self.last_search_query is None:
            return True

        # Simple similarity check: count common words
        # For better similarity, could use embeddings, but this is fast and effective
        query_words = set(rewritten_query.lower().split())
        last_words = set(self.last_search_query.lower().split())

        # Calculate Jaccard similarity
        if not query_words or not last_words:
            return True

        intersection = query_words & last_words
        union = query_words | last_words
        similarity = len(intersection) / len(union)

        # Retrieve if queries are less similar than threshold
        should_retrieve = similarity < self.query_similarity_threshold

        logger.info(
            f"Query similarity: {similarity:.2f} (threshold: {self.query_similarity_threshold}). "
            f"{'Retrieving new papers' if should_retrieve else 'Reusing previous context'}"
        )

        return should_retrieve
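
    # Illustrative note (not part of the retrieval logic above): the Jaccard
    # similarity compares the word sets of consecutive rewritten queries. For
    # example, "deep learning networks" vs. "deep learning models" share 2 of
    # 4 unique words, giving a similarity of 0.5. With an assumed threshold of
    # 0.7 (the actual value comes from config.query_similarity_threshold), that
    # pair would still trigger a fresh retrieval; only near-identical follow-up
    # queries reuse the cached papers and context.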

    def query(
        self,
        question: str,
        n_results: Optional[int] = None,
        metadata_filter: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Query the RAG system with a question.

        Parameters
        ----------
        question : str
            User's question.
        n_results : int, optional
            Number of papers to retrieve for context.
        metadata_filter : dict, optional
            Metadata filter for paper search.
        system_prompt : str, optional
            Custom system prompt for the model.

        Returns
        -------
        dict
            Dictionary containing:
            - response: str - Generated response
            - papers: list - Retrieved papers used as context
            - metadata: dict - Additional metadata

        Raises
        ------
        RAGError
            If query fails.

        Examples
        --------
        >>> response = chat.query("What is attention mechanism?")
        >>> print(response["response"])
        >>> print(f"Based on {len(response['papers'])} papers")
        """
        try:
            if n_results is None:
                n_results = self.max_context_papers

            # Rewrite query for better semantic search (if enabled)
            if self.enable_query_rewriting:
                rewritten_query = self._rewrite_query(question)
            else:
                rewritten_query = question
                logger.info("Query rewriting disabled, using original query")

            # Check if we should retrieve new papers or reuse previous context
            should_retrieve = (
                self._should_retrieve_papers(rewritten_query) if self.enable_query_rewriting else True
            )

            if should_retrieve:
                # Search for relevant papers using rewritten query
                logger.info(f"Searching for papers with rewritten query: {rewritten_query}")
                search_results = self.embeddings_manager.search_similar(
                    rewritten_query, n_results=n_results, where=metadata_filter
                )

                # Store rewritten query for next comparison
                self.last_search_query = rewritten_query

                if not search_results["ids"][0]:
                    logger.warning("No relevant papers found")
                    return {
                        "response": "I couldn't find any relevant papers to answer your question.",
                        "papers": [],
                        "metadata": {"n_papers": 0, "rewritten_query": rewritten_query},
                    }

                # Format context from papers using shared utility
                papers = format_search_results(search_results, self.database, include_documents=True)
                context = build_context_from_papers(papers)

                # Cache papers for potential reuse
                self._cached_papers = papers
                self._cached_context = context
            else:
                # Reuse cached papers and context from previous query
                logger.info("Reusing cached papers from previous query")
                papers = getattr(self, "_cached_papers", [])
                context = getattr(self, "_cached_context", "")

                if not papers:
                    # Fallback: retrieve papers if cache is empty
                    logger.warning("Cache empty, retrieving papers")
                    search_results = self.embeddings_manager.search_similar(
                        rewritten_query, n_results=n_results, where=metadata_filter
                    )
                    self.last_search_query = rewritten_query

                    if not search_results["ids"][0]:
                        logger.warning("No relevant papers found")
                        return {
                            "response": "I couldn't find any relevant papers to answer your question.",
                            "papers": [],
                            "metadata": {"n_papers": 0, "rewritten_query": rewritten_query},
                        }

                    # Format context from papers using shared utility
                    papers = format_search_results(search_results, self.database, include_documents=True)
                    context = build_context_from_papers(papers)

                    # Cache papers for potential reuse
                    self._cached_papers = papers
                    self._cached_context = context

            # Generate response using OpenAI API (uses original question, not rewritten query)
            logger.info(f"Generating response using {len(papers)} papers as context")
            response_text = self._generate_response(question, context, system_prompt)

            # Store in conversation history
            self.conversation_history.append({"role": "user", "content": question})
            self.conversation_history.append({"role": "assistant", "content": response_text})

            return {
                "response": response_text,
                "papers": papers,
                "metadata": {
                    "n_papers": len(papers),
                    "model": self.model,
                    "rewritten_query": rewritten_query,
                    "retrieved_new_papers": should_retrieve,
                },
            }

        except PaperFormattingError as e:
            raise RAGError(f"Failed to format papers: {str(e)}") from e
        except Exception as e:
            raise RAGError(f"Query failed: {str(e)}") from e

    def chat(
        self,
        message: str,
        use_context: bool = True,
        n_results: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Continue conversation with context awareness.

        Parameters
        ----------
        message : str
            User's message.
        use_context : bool, optional
            Whether to retrieve papers as context, by default True
        n_results : int, optional
            Number of papers to retrieve.

        Returns
        -------
        dict
            Dictionary containing response and metadata.

        Examples
        --------
        >>> response = chat.chat("Tell me more about transformers")
        >>> print(response["response"])
        """
        if use_context:
            return self.query(message, n_results=n_results)
        else:
            # Use only conversation history without paper context
            response_text = self._generate_response(message, "", None)

            self.conversation_history.append({"role": "user", "content": message})
            self.conversation_history.append({"role": "assistant", "content": response_text})

            return {
                "response": response_text,
                "papers": [],
                "metadata": {"n_papers": 0, "model": self.model},
            }

    def reset_conversation(self):
        """
        Reset conversation history and cached context.

        Examples
        --------
        >>> chat.reset_conversation()
        """
        self.conversation_history = []
        self.last_search_query = None
        self._cached_papers = []
        self._cached_context = ""
        logger.info("Conversation history and cache reset")

    def _generate_response(self, question: str, context: str, system_prompt: Optional[str] = None) -> str:
        """
        Generate response using OpenAI-compatible API with optional MCP tool support.

        Parameters
        ----------
        question : str
            User's question.
        context : str
            Context from papers.
        system_prompt : str, optional
            Custom system prompt.

        Returns
        -------
        str
            Generated response.

        Raises
        ------
        RAGError
            If generation fails.
        """
        # Default system prompt
        if system_prompt is None:
            system_prompt = (
                "You are an AI assistant helping researchers find relevant NeurIPS abstracts. "
                "Use the provided paper abstracts to answer questions accurately and concisely. "
                "If the papers don't contain enough information to answer the question, suggest a query that might return more relevant results. "
                "Always cite which papers you're referencing (e.g., 'Paper 1', 'Paper 2'), using local links: <a href='#paper-1'>Paper-1</a>, <a href='#paper-2'>Paper-2</a> etc."
            )

        # Enhance system prompt when MCP tools are enabled
        if self.enable_mcp_tools:
            system_prompt += (
                "\n\nYou have access to clustering analysis tools that can help answer questions about "
                "overall conference topics, trends over time, and recent developments. Use these tools when appropriate "
                "for questions about general themes, topic evolution, or recent research in specific areas."
            )

        # Build messages
        messages = [{"role": "system", "content": system_prompt}]

        # Add conversation history (limit to last 10 messages)
        if self.conversation_history:
            messages.extend(self.conversation_history[-10:])

        # Add current question with context
        if context:
            user_message = f"Context from relevant papers:\n\n{context}\n\nQuestion: {question}"
        else:
            user_message = question

        messages.append({"role": "user", "content": user_message})

        # Prepare API call parameters
        api_params = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": 1000,
            "timeout": 180,
        }

        # Add MCP tools if enabled
        if self.enable_mcp_tools:
            api_params["tools"] = get_mcp_tools_schema()
            api_params["tool_choice"] = "auto"  # Let the model decide when to use tools

        # Call OpenAI-compatible API using OpenAI client
        try:
            response = self.openai_client.chat.completions.create(**api_params)

            # Check if the model wants to use tools
            if self.enable_mcp_tools and response.choices[0].message.tool_calls:
                return self._handle_tool_calls(response, messages)

            # No tool calls, return direct response
            return response.choices[0].message.content

        except Exception as e:
            raise RAGError(f"Failed to generate response: {str(e)}") from e

    def _handle_tool_calls(self, initial_response, messages: List[Dict[str, Any]]) -> str:
        """
        Handle tool calls from the LLM response.

        This method executes requested MCP tools and sends the results back
        to the LLM to generate a final response.

        Parameters
        ----------
        initial_response
            The initial API response containing tool calls.
        messages : list
            The message history to continue the conversation.

        Returns
        -------
        str
            Final response from the LLM after tool execution.
        """
        assistant_message = initial_response.choices[0].message
        tool_calls = assistant_message.tool_calls

        # Add the assistant's tool call message to history
        tool_calls_data: List[Dict[str, Any]] = [
            {
                "id": tc.id,
                "type": "function",
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in tool_calls
        ]
        messages.append(
            {
                "role": "assistant",
                "content": assistant_message.content,
                "tool_calls": tool_calls_data,
            }
        )

        # Execute each tool call
        for tool_call in tool_calls:
            function_name = tool_call.function.name
            try:
                function_args = json.loads(tool_call.function.arguments)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse tool arguments: {e}")
                # Use empty args as fallback
                function_args = {}

            logger.info(f"LLM requested tool: {function_name} with args: {function_args}")

            # Execute the tool
            tool_result = execute_mcp_tool(function_name, function_args)

            # Format the result for better LLM consumption
            formatted_result = format_tool_result_for_llm(function_name, tool_result)

            # Add tool result to messages
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "name": function_name,
                    "content": formatted_result,
                }
            )

        # Get final response from LLM with tool results
        try:
            final_response = self.openai_client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=1000,
                timeout=180,
            )
            return final_response.choices[0].message.content

        except Exception as e:
            raise RAGError(f"Failed to generate response after tool calls: {str(e)}") from e

    def export_conversation(self, output_path: Path) -> None:
        """
        Export conversation history to JSON file.

        Parameters
        ----------
        output_path : Path
            Path to output JSON file.

        Examples
        --------
        >>> chat.export_conversation("conversation.json")
        """
        output_path = Path(output_path)

        with open(output_path, "w") as f:
            json.dump(self.conversation_history, f, indent=2)

        logger.info(f"Conversation exported to {output_path}")
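

# --- Illustrative usage sketch (not part of the module above) ---
# A minimal end-to-end flow, assuming a local OpenAI-compatible backend is
# running and the embeddings store and database are already populated:
#
#   from abstracts_explorer.embeddings import EmbeddingsManager
#   from abstracts_explorer.database import DatabaseManager
#   from abstracts_explorer.rag import RAGChat
#
#   em = EmbeddingsManager()
#   em.connect()
#   db = DatabaseManager()
#   db.connect()
#
#   chat = RAGChat(em, db, enable_mcp_tools=True)
#   first = chat.query("What is the attention mechanism?")           # retrieves papers
#   follow_up = chat.chat("How is it used in vision transformers?")  # may reuse the cached context
#   chat.export_conversation("conversation.json")
#   chat.reset_conversation()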