"""
MCP Tools Integration for RAG Chat
==================================
This module provides integration between MCP clustering tools and the RAG chat system.
It converts MCP tool definitions to OpenAI function calling format and handles tool execution.
The integration allows the LLM to automatically decide when to use clustering tools
to answer questions about conference topics, trends, and developments.
"""
import copy
import inspect
import json
import logging
from typing import Callable, Dict, List, Any, Optional
from abstracts_explorer.mcp_server import (
get_conference_topics,
get_topic_evolution,
search_papers,
get_cluster_visualization,
analyze_topic_relevance,
get_paper_details,
)
# Module-level logger named after this module; helpers below use it to warn
# about tool arguments that get dropped before dispatch.
logger = logging.getLogger(__name__)
def _abbreviate_result(text: str, max_length: int = 200) -> str:
"""
Abbreviate a result string for logging.
Parameters
----------
text : str
The text to abbreviate.
max_length : int
Maximum number of characters to keep (default: 200).
Returns
-------
str
The original text if short enough, otherwise truncated with '…'.
"""
if len(text) <= max_length:
return text
return text[:max_length] + "…"
def _normalize_search_papers_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalise argument shapes produced by LLMs for the ``search_papers`` tool.
LLMs occasionally produce slightly wrong argument shapes, e.g. a singular
``"year"`` key instead of ``"years"``, or a list for scalar string fields.
This function corrects those mismatches so the downstream function call
always receives the expected types.
Normalizations applied:
* ``year`` (int or list) → ``years`` (list of int)
* ``topic_keywords`` as a list → joined string
* ``conference`` as a list → first element string
* ``conferences`` (list, wrong field name) → ``conference`` (str, first element)
Parameters
----------
arguments : dict
Raw arguments dict coming from the LLM / ``execute_mcp_tool`` caller.
Returns
-------
dict
A new dict with normalized argument values.
"""
args = dict(arguments)
# Normalize 'year' → 'years' (LLMs often use singular form)
if "year" in args and "years" not in args:
year_val = args.pop("year")
if isinstance(year_val, list):
args["years"] = year_val
else:
args["years"] = [year_val]
elif "year" in args:
args.pop("year") # 'years' already present; drop the duplicate
# Normalize topic_keywords: list → space-joined string
if "topic_keywords" in args and isinstance(args["topic_keywords"], list):
args["topic_keywords"] = " ".join(str(k) for k in args["topic_keywords"])
# Normalize conference: list → first element string
if "conference" in args and isinstance(args["conference"], list):
args["conference"] = args["conference"][0] if args["conference"] else None
# Normalize 'conferences' (wrong field name) → 'conference' if not already set
if "conferences" in args and "conference" not in args:
conferences_val = args.pop("conferences")
if isinstance(conferences_val, list) and conferences_val:
args["conference"] = conferences_val[0]
elif isinstance(conferences_val, str):
args["conference"] = conferences_val
elif "conferences" in args:
args.pop("conferences") # 'conference' already present; drop duplicate
return args
def _normalize_get_topic_evolution_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalise argument shapes produced by LLMs for the ``get_topic_evolution`` tool.
Parameters
----------
arguments : dict
Raw arguments dict from the LLM.
Returns
-------
dict
A new dict with normalized argument values.
"""
args = dict(arguments)
# Normalize topic_keywords: list → space-joined string
if "topic_keywords" in args and isinstance(args["topic_keywords"], list):
args["topic_keywords"] = " ".join(str(k) for k in args["topic_keywords"])
# Normalize start_year / end_year: list → first element int
for key in ("start_year", "end_year"):
if key in args and isinstance(args[key], list):
args[key] = args[key][0] if args[key] else None
return args
def _normalize_analyze_topic_relevance_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalise argument shapes produced by LLMs for the ``analyze_topic_relevance`` tool.
Parameters
----------
arguments : dict
Raw arguments dict from the LLM.
Returns
-------
dict
A new dict with normalized argument values.
"""
args = dict(arguments)
# Normalize topic: list → space-joined string
if "topic" in args and isinstance(args["topic"], list):
args["topic"] = " ".join(str(k) for k in args["topic"])
# Normalize conference → conferences if wrong field name used
if "conference" in args and "conferences" not in args:
conf_val = args.pop("conference")
if isinstance(conf_val, str):
args["conferences"] = [conf_val]
elif isinstance(conf_val, list):
args["conferences"] = conf_val
return args
def _normalize_get_paper_details_args(arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
Normalise argument shapes produced by LLMs for the ``get_paper_details`` tool.
Normalizations applied:
* ``year`` as a list → first element int
* ``year`` as a string → int
* ``conference`` as a list → first element string
Parameters
----------
arguments : dict
Raw arguments dict from the LLM.
Returns
-------
dict
A new dict with normalized argument values.
"""
args = dict(arguments)
# Normalize year: list → first element
if "year" in args and isinstance(args["year"], list):
args["year"] = args["year"][0] if args["year"] else None
# Normalize year: string → int
if "year" in args and isinstance(args["year"], str):
try:
args["year"] = int(args["year"])
except (ValueError, TypeError):
args["year"] = None
# Normalize conference: list → first element string
if "conference" in args and isinstance(args["conference"], list):
args["conference"] = args["conference"][0] if args["conference"] else None
return args
def _filter_unknown_kwargs(func: Callable, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
Filter out keyword arguments that are not accepted by *func*, logging a
warning for each unexpected key.
This makes MCP tool dispatch tolerant of extra keys that an LLM may send
(e.g. it produces ``{"year": 2025}`` in addition to ``{"years": [2025]}``
after normalisation has already renamed the field).
Parameters
----------
func : callable
The target function whose signature is used to determine valid keys.
kwargs : dict
Keyword arguments intended for *func*.
Returns
-------
dict
A copy of *kwargs* with unrecognised keys removed.
"""
try:
sig = inspect.signature(func)
valid_params = set(sig.parameters.keys())
# If the function accepts **kwargs itself, pass everything through
has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
if has_var_keyword:
return dict(kwargs)
except (ValueError, TypeError):
# If we can't inspect the signature, pass everything through unchanged
return dict(kwargs)
filtered: Dict[str, Any] = {}
for key, value in kwargs.items():
if key in valid_params:
filtered[key] = value
else:
logger.warning(f"Ignoring unknown argument '{key}' for {func.__name__}(); " "this key will be dropped.")
return filtered
# Tool definitions in OpenAI function-calling format: one entry per MCP tool,
# each with a name, a natural-language description the LLM uses to decide when
# to call it, and a JSON-Schema description of its parameters.
#
# Descriptions are built with implicit string concatenation, so every fragment
# that is followed by another one must end with a space.
MCP_TOOLS_SCHEMA = [
    {
        "type": "function",
        "function": {
            "name": "analyze_topic_relevance",
            "description": (
                "Analyze the relevance of a research topic by counting papers within a specified "
                "distance in embedding space. Use this tool when the user asks about: topic relevance, "
                "popularity of a research area, how many papers cover a topic, or identifying significant "
                "research themes at a conference. A conference must be specified."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "topic": {
                        "type": "string",
                        "description": "The topic or research question to analyze (e.g., 'Uncertainty quantification')",
                    },
                    "distance_threshold": {
                        "type": "number",
                        "description": "Maximum Euclidean distance to consider papers relevant (default: 1.1)",
                    },
                    "conferences": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Conference names to analyze (e.g., ['NeurIPS']). Required.",
                    },
                    "years": {
                        "type": "array",
                        "items": {"type": "integer"},
                        "description": "Filter by specific years (e.g., [2024, 2025])",
                    },
                    "collection_name": {"type": "string", "description": "Name of ChromaDB collection (optional)"},
                },
                "required": ["topic", "conferences"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_conference_topics",
            "description": (
                "Get the main research topics of a conference. "
                "Use this tool when the user asks about: overall themes, main topics, research areas, "
                "or wants to understand what topics are covered in the conference. "
                "Returns topic names, representative keywords, paper counts, and example titles. "
                "A conference must be specified."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "conferences": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Conference names (e.g., ['NeurIPS']). Required.",
                    },
                    "years": {
                        "type": "array",
                        "items": {"type": "integer"},
                        "description": "Filter by specific years (e.g., [2024, 2025])",
                    },
                    "collection_name": {
                        "type": "string",
                        "description": "Name of ChromaDB collection (optional, uses config default)",
                    },
                },
                "required": ["conferences"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_topic_evolution",
            "description": (
                "Analyze how specific topics have evolved over the years. "
                "Use this tool when the user asks about: trends over time, historical development, how a topic has developed, "
                "how a topic has changed, or evolution of research areas. At least one conference must be specified. "
                "If requested, multiple conferences can be compared in the same analysis. "
                "The chat frontend can use the returned data to generate a plot with plotly.js showing the topic evolution over time."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "topic_keywords": {
                        "type": "string",
                        "description": "Keywords describing the topic (e.g., 'transformers attention', 'reinforcement learning')",
                    },
                    "conferences": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Conference names to analyze (e.g., ['NeurIPS', 'ICLR']). Required.",
                    },
                    "start_year": {"type": "integer", "description": "Start year for analysis (inclusive)"},
                    "end_year": {"type": "integer", "description": "End year for analysis (inclusive)"},
                    "distance_threshold": {
                        "type": "number",
                        "description": (
                            "Maximum Euclidean distance in embedding space to consider papers "
                            "relevant (default: 1.1). Lower values mean stricter matching."
                        ),
                    },
                    "collection_name": {"type": "string", "description": "Name of ChromaDB collection (optional)"},
                },
                "required": ["topic_keywords", "conferences"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "search_papers",
            "description": (
                "Search for papers on a specific topic. "
                "Use this tool when the user asks about: papers on a topic, research about something, "
                "specific work, or wants to find papers related to a particular area. "
                "Can filter by specific years or search all years. A conference must be specified."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "topic_keywords": {
                        "type": "string",
                        "description": "Keywords describing the topic to search for",
                    },
                    "years": {
                        "type": "array",
                        "items": {"type": "integer"},
                        "description": "List of specific years to filter by (e.g., [2024, 2025]). If not provided, searches all years.",
                    },
                    "n_results": {"type": "integer", "description": "Number of papers to return (default: 10)"},
                    "conference": {
                        "type": "string",
                        "description": "Conference name to search (e.g., 'NeurIPS', 'ICLR'). Required.",
                    },
                    "where": {"type": "object", "description": "Custom ChromaDB WHERE clause for filtering"},
                    "collection_name": {"type": "string", "description": "Name of ChromaDB collection (optional)"},
                },
                "required": ["topic_keywords", "conference"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_cluster_visualization",
            "description": (
                "Retrieve pre-computed visualization data for clustered embeddings. "
                "Use this tool when the user asks for: a visual representation, graphical view, "
                "or wants to see clusters displayed. A conference must be specified. "
                "The chat frontend can use the returned data to generate a plot with plotly.js showing the clusters."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "conferences": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Conference names to retrieve visualization for (e.g., ['NeurIPS']). Required.",
                    },
                    "years": {
                        "type": "array",
                        "items": {"type": "integer"},
                        "description": "Filter by specific years (e.g., [2024, 2025])",
                    },
                    "output_path": {"type": "string", "description": "Path to save visualization JSON (optional)"},
                    "collection_name": {"type": "string", "description": "Name of ChromaDB collection (optional)"},
                },
                "required": ["conferences"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_paper_details",
            "description": (
                "Get detailed information about papers from the database, including authors, "
                "URLs, PDF links, session information, keywords, and awards. "
                "Use this tool when the user asks about: who wrote a paper, paper authors, "
                "where to find a paper, PDF or poster links, session or room details, "
                "paper awards, or any other metadata about a specific paper. "
                "You must specify either paper_id or title to identify the paper. "
                "Conference + year are optional but can help disambiguate papers with similar titles or multiple versions. "
                "Do not use for searching for papers on a topic; use the 'search_papers' tool for that instead. "
                "This tool is useful for specific follow-up questions after searching for papers using semantic search."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "title": {
                        "type": "string",
                        "description": "Title or partial title to search for (case-insensitive)",
                    },
                    "paper_id": {
                        "type": "string",
                        "description": "Unique paper identifier (uid or original conference/OpenReview ID)",
                    },
                    "conference": {
                        "type": "string",
                        "description": "Filter by conference name (e.g., 'NeurIPS', 'ICLR')",
                    },
                    "year": {
                        "type": "integer",
                        "description": "Filter by publication year",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of papers to return when searching by title (default: 5)",
                    },
                },
                # Neither identifier is individually mandatory, so the schema
                # itself requires nothing; the description states the rule.
                "required": [],
            },
        },
    },
]
def _format_topic_relevance_result(data: Dict[str, Any]) -> str:
"""Format topic relevance result for LLM."""
lines = [f"Topic Relevance Analysis for '{data.get('topic', 'unknown')}':\n"]
total = data.get("total_papers", 0)
total_considered = data.get("total_considered", 0)
distance = data.get("distance_threshold", 0)
relevance = data.get("relevance_score", 0)
lines.append(f"Papers found: {total}/{total_considered} within distance {distance}")
lines.append(f"Relevance score: {relevance}/100\n")
if total > 0:
# Show conferences
conferences = data.get("conferences", {})
if conferences:
lines.append("Conferences:")
for conf, count in list(conferences.items())[:5]:
lines.append(f" {conf}: {count} papers")
# Show years
years = data.get("years", {})
if years:
lines.append("\nYears:")
for year, count in sorted(years.items()):
lines.append(f" {year}: {count} papers")
# Show sample papers
sample_papers = data.get("sample_papers", [])
if sample_papers:
lines.append("\nClosest papers:")
for i, paper in enumerate(sample_papers[:3], 1):
title = paper.get("title", "Unknown")
dist = paper.get("distance", 0)
lines.append(f" {i}. {title} (distance: {dist:.3f})")
closest = data.get("closest_distance")
if closest is not None:
lines.append(f"\nClosest paper distance: {closest:.3f}")
else:
lines.append("\nNo papers found matching the topic within the distance threshold.")
return "\n".join(lines)
def _format_conference_topics_result(data: Dict[str, Any]) -> str:
"""Format conference topics result for LLM."""
conference = data.get("conference", "unknown")
total = data.get("total_papers", 0)
n_topics = data.get("n_topics", 0)
lines = [f"Conference Topics for {conference} ({total} papers, {n_topics} topics):\n"]
topics = data.get("topics", [])
for topic in topics[:10]: # Limit to top 10 topics
name = topic.get("topic") or "Unnamed"
paper_count = topic.get("paper_count", 0)
keywords = topic.get("keywords", [])
lines.append(f"\n{name} ({paper_count} papers):")
if keywords:
lines.append(f" Keywords: {', '.join(keywords[:8])}")
return "\n".join(lines)
def _format_topic_evolution_result(data: Dict[str, Any]) -> str:
"""Format topic evolution result for LLM."""
lines = [f"Topic Evolution Analysis for '{data.get('topic', 'unknown')}':\n"]
conference_data = data.get("conference_data", {})
if conference_data:
for conference, cdata in conference_data.items():
lines.append(f"Conference: {conference}")
year_counts = cdata.get("year_counts", {})
year_relative = cdata.get("year_relative", {})
if year_counts:
lines.append(" Papers per year:")
for year, count in sorted(year_counts.items()):
rel = year_relative.get(year, year_relative.get(str(year), 0))
lines.append(f" {year}: {count} papers ({rel}%)")
lines.append("")
total = data.get("total_papers", 0)
lines.append(f"Total papers found: {total}")
return "\n".join(lines)
def _format_search_papers_result(data: Dict[str, Any]) -> str:
"""Format search papers result for LLM."""
lines = [f"Search Results for '{data.get('topic', 'unknown')}':\n"]
papers = data.get("papers", [])
years_filter = data.get("years_filter")
if years_filter:
lines.append(f"Filtered by years: {years_filter}")
lines.append(f"Found {len(papers)} papers:\n")
for i, paper in enumerate(papers[:5], 1): # Top 5 papers
title = paper.get("title", "Unknown")
year = paper.get("year", "")
lines.append(f"{i}. {title} ({year})")
# Add abstract snippet if available
abstract = paper.get("abstract", "")
if abstract:
snippet = abstract[:150] + "..." if len(abstract) > 150 else abstract
lines.append(f" {snippet}")
return "\n".join(lines)
def _format_visualization_result(data: Dict[str, Any]) -> str:
"""Format visualization result for LLM."""
lines = ["Cluster Visualization Data Generated:\n"]
stats = data.get("statistics", {})
n_points = data.get("n_points", 0)
n_dims = data.get("n_dimensions", 0)
lines.append(f"Generated {n_dims}D visualization with {n_points} points")
lines.append(f"Clusters: {stats.get('n_clusters', 0)}")
if data.get("visualization_saved"):
lines.append(f"Saved to: {data.get('output_path')}")
return "\n".join(lines)
def _format_paper_details_result(data: Dict[str, Any]) -> str:
"""Format paper details result for LLM."""
papers = data.get("papers", [])
lines = [f"Paper Details ({len(papers)} found):\n"]
if not papers:
lines.append("No papers found matching the given criteria.")
return "\n".join(lines)
for i, paper in enumerate(papers, 1):
title = paper.get("title", "Unknown")
year = paper.get("year", "")
conference = paper.get("conference", "")
lines.append(f"\n{i}. {title} ({conference} {year})")
authors = paper.get("authors") or []
if authors:
lines.append(f" Authors: {', '.join(authors)}")
url = paper.get("url") or ""
if url:
lines.append(f" URL: {url}")
pdf = paper.get("paper_pdf_url") or ""
if pdf:
lines.append(f" PDF: {pdf}")
session = paper.get("session") or ""
room = paper.get("room_name") or ""
if session or room:
location = f"{session}" + (f" ({room})" if room else "")
lines.append(f" Session: {location}")
keywords = paper.get("keywords") or ""
if keywords:
lines.append(f" Keywords: {keywords}")
award = paper.get("award") or ""
if award:
lines.append(f" Award: {award}")
abstract = paper.get("abstract") or ""
if abstract:
snippet = abstract[:200] + "..." if len(abstract) > 200 else abstract
lines.append(f" Abstract: {snippet}")
return "\n".join(lines)