Source code for abstracts_explorer.mcp_server

"""
MCP Server for Cluster Analysis
================================

This module provides a Model Context Protocol (MCP) server that exposes
tools for analyzing clustered embeddings. The server enables LLM-based
assistants to answer questions about conference paper topics, trends,
and developments.

Features:
- Get most frequently mentioned topics from clusters
- Analyze topic evolution over years
- Find recent developments in specific topics
- Generate cluster visualizations
"""

import logging
import json
import re
from typing import Any, Dict, List, Optional, Tuple
from collections import Counter
from copy import deepcopy

from mcp.server.fastmcp import FastMCP

from abstracts_explorer.embeddings import EmbeddingsManager
from abstracts_explorer.database import DatabaseManager
from abstracts_explorer.clustering import ClusteringManager
from abstracts_explorer.config import get_config

logger = logging.getLogger(__name__)

# Initialize FastMCP server
mcp = FastMCP("Abstracts Explorer Cluster Analysis")


[docs] class ClusterAnalysisError(Exception): """Exception raised for cluster analysis errors.""" pass
[docs] def load_clustering_data( collection_name: Optional[str] = None, ) -> tuple[ClusteringManager, DatabaseManager]: """ Load clustering data and database. Parameters ---------- collection_name : str, optional Name of the ChromaDB collection Returns ------- tuple[ClusteringManager, DatabaseManager] Clustering manager and database manager instances Raises ------ ClusterAnalysisError If loading fails """ config = get_config() # Use config defaults if not provided collection_name = collection_name or config.collection_name try: # Initialize embeddings manager em = EmbeddingsManager( collection_name=collection_name, ) em.connect() em.create_collection() # Initialize database manager db = DatabaseManager() db.connect() # Initialize clustering manager cm = ClusteringManager(em, db) return cm, db except Exception as e: raise ClusterAnalysisError(f"Failed to load clustering data: {str(e)}") from e
def _apply_cached_cluster_labels( cm: ClusteringManager, cached_results: Dict[str, Any], ) -> None: """ Restore ``cluster_labels``, ``cluster_label_names``, and ``cluster_keywords`` on *cm* from cached clustering results. Parameters ---------- cm : ClusteringManager Clustering manager with embeddings already loaded (``cm.paper_ids`` must be populated). cached_results : dict Cached results dict containing a ``"points"`` list where each element has ``"id"`` (or ``"paper_id"``) and ``"cluster"`` keys. May also contain ``"cluster_labels"`` (cluster name dict) and ``"cluster_keywords"`` (TF-IDF keyword dict). """ import numpy as np point_id_to_cluster: Dict[str, int] = {} for point in cached_results.get("points", []): pid = point.get("id") or point.get("paper_id", "") point_id_to_cluster[pid] = point.get("cluster", -1) current_ids = cm.paper_ids or [] cm.cluster_labels = np.array([point_id_to_cluster.get(pid, -1) for pid in current_ids]) # Restore cluster label names (LLM-generated or TF-IDF-based names) if cached_results.get("cluster_labels"): cm.cluster_label_names = {int(k): v for k, v in cached_results["cluster_labels"].items()} # Restore TF-IDF cluster keywords if cached_results.get("cluster_keywords"): cm.cluster_keywords = {int(k): v for k, v in cached_results["cluster_keywords"].items()}
[docs] def analyze_cluster_topics( cm: ClusteringManager, db: DatabaseManager, cluster_id: int, use_llm: bool = False, ) -> Dict[str, Any]: """ Analyze a single topic (cluster) and return a concise summary. Each cluster represents a conference topic. The returned dictionary is designed to be consumed directly by an LLM — field names use the word *topic* instead of *cluster* so the model does not need to know about the underlying clustering implementation. Parameters ---------- cm : ClusteringManager Clustering manager with loaded data db : DatabaseManager Database manager for paper metadata cluster_id : int Internal cluster ID to analyze use_llm : bool, optional Whether to use LLM for topic extraction (default: False) Returns ------- dict Dictionary containing: - topic: Human-readable topic name (or ``None``) - paper_count: Number of papers in this topic - keywords: Representative keywords for the topic - sample_titles: A few example paper titles """ if cm.cluster_labels is None or cm.paper_ids is None or cm.metadatas is None: raise ClusterAnalysisError("Clustering data not loaded. Call load_embeddings() and cluster() first.") label_names = cm.cluster_label_names or {} cluster_kws = cm.cluster_keywords or {} # Find papers in this cluster cluster_indices = [i for i, label in enumerate(cm.cluster_labels) if label == cluster_id] if not cluster_indices: return { "topic": label_names.get(cluster_id), "paper_count": 0, "keywords": cluster_kws.get(cluster_id, []), "sample_titles": [], } # Extract sample titles sample_titles: list[str] = [] for idx in cluster_indices: if len(sample_titles) >= 5: break title = cm.metadatas[idx].get("title", "") if title: sample_titles.append(title) return { "topic": label_names.get(cluster_id), "paper_count": len(cluster_indices), "keywords": cluster_kws.get(cluster_id, []), "sample_titles": sample_titles, }
def _parse_conference_year(conference: str) -> Tuple[str, Optional[int]]: """ Parse a trailing year from a conference name. LLMs often combine the conference name and year into a single string (e.g. ``"NeurIPS 2025"``). This helper splits them so the conference name matches the database/cache entries which store the name and year separately. Parameters ---------- conference : str Conference name, possibly with a trailing 4-digit year. Returns ------- tuple of (str, int or None) ``(conference_name, year)`` — *year* is ``None`` when no trailing year was found. Examples -------- >>> _parse_conference_year("NeurIPS 2025") ('NeurIPS', 2025) >>> _parse_conference_year("ICLR") ('ICLR', None) """ match = re.match(r"^(.+)\s+(\d{4})$", conference.strip()) if match: return match.group(1), int(match.group(2)) return conference, None def _lookup_clustering_cache( db: "DatabaseManager", config: Any, conference: str, years: Optional[List[int]], ) -> Any: """ Look up pre-computed clustering results from the database cache. Tries an exact match first. When *years* is non-empty and no exact match is found, retries without the year filter so that an all-years cache entry can serve as a fallback. Parameters ---------- db : DatabaseManager Open database connection. config : object Configuration object (needs ``embedding_model``). conference : str Conference name (already parsed, no trailing year). years : list of int or None Year filter. Returns ------- dict or None Cached clustering results, or ``None`` when nothing matches. """ clustering_params: Dict[str, Any] = { "linkage": "ward", "distance_threshold": 150.0, } # Determine single year for column-based lookup cache_year = years[0] if years and len(years) == 1 else None cached = db.get_clustering_cache( embedding_model=config.embedding_model, reduction_method="tsne", n_components=2, clustering_method="agglomerative", n_clusters=None, clustering_params=clustering_params, conference=conference, year=cache_year, ) # Fallback: if per-year cache not found, try the all-years cache if not cached and years: cached = db.get_clustering_cache( embedding_model=config.embedding_model, reduction_method="tsne", n_components=2, clustering_method="agglomerative", n_clusters=None, clustering_params=clustering_params, conference=conference, year=None, ) return cached def _get_conference_topics_for_single_conference( conference: str, years: Optional[List[int]] = None, collection_name: Optional[str] = None, ) -> Dict[str, Any]: """ Retrieve the main research topics for a single conference. Looks up pre-computed clustering results from the database cache and returns a topic-centric summary. Returns an error dict when no cached results exist for the requested conference/year combination. If the *conference* string contains a trailing year (e.g. ``"NeurIPS 2025"``), the year is extracted and merged into *years* so that the cache lookup matches entries stored under the plain conference name. Parameters ---------- conference : str Conference name (e.g. "NeurIPS", "ICLR", or "NeurIPS 2025"). years : list of int, optional Filter by publication years. collection_name : str, optional Name of ChromaDB collection (uses config default if not provided). Returns ------- dict Topics result dict with ``"topic_sizes"`` and ``"topics"`` keys, or an ``"error"`` key if no cache is available. """ config = get_config() collection_name = collection_name or config.collection_name # Parse year from conference name if present (e.g. "NeurIPS 2025" → "NeurIPS", 2025) conference, extracted_year = _parse_conference_year(conference) if extracted_year is not None: if years is None: years = [extracted_year] elif extracted_year not in years: years = sorted(years + [extracted_year]) cm, db = load_clustering_data(collection_name) try: cached = _lookup_clustering_cache(db, config, conference, years) if not cached: return { "error": ( f"No pre-computed clustering data available for conference " f"'{conference}'" + (f" years={years}" if years else "") + ". Run 'abstracts-explorer clustering pre-generate' first." ), } # Reconstruct ClusteringManager state from cached results cm.load_embeddings(conferences=[conference], years=years) _apply_cached_cluster_labels(cm, cached) stats = cm.get_cluster_statistics() # Build topic_sizes with human-readable names, sorted by size descending label_names = cm.cluster_label_names or {} named_sizes = {label_names.get(cid, f"Topic {cid}"): size for cid, size in stats["cluster_sizes"].items()} topic_sizes = dict(sorted(named_sizes.items(), key=lambda x: x[1], reverse=True)) topics = [] for cluster_id in range(stats["n_clusters"]): topic = analyze_cluster_topics(cm, db, cluster_id) topics.append(topic) # Sort topics by paper_count descending topics.sort(key=lambda t: t["paper_count"], reverse=True) return { "conference": conference, "total_papers": stats["total_papers"], "n_topics": stats["n_clusters"], "topic_sizes": topic_sizes, "topics": topics, } finally: cm.embeddings_manager.close() db.close()
[docs] @mcp.tool() def get_conference_topics( conferences: Optional[List[str]] = None, years: Optional[List[int]] = None, collection_name: Optional[str] = None, **kwargs, ) -> str: """ Get the main research topics of a conference. Returns the key research topics covered at the conference, each with a descriptive name, representative keywords, paper count, and example paper titles. A conference must be specified. When multiple conferences are provided, each conference is analyzed individually and results are combined. Parameters ---------- conferences : list of str, optional Conference names (e.g. ["NeurIPS"]). Required – returns an error when not provided. years : list of int, optional Filter by publication years. collection_name : str, optional Name of ChromaDB collection (uses config default if not provided). **kwargs Ignored (for backwards compatibility with old tool schemas). Returns ------- str JSON string containing the conference topics analysis. """ try: if not conferences: return json.dumps( { "error": ( "A conference must be specified for topic analysis. " "Please provide conferences parameter." ) }, indent=2, ) all_results: List[Dict[str, Any]] = [] for conf in conferences: result = _get_conference_topics_for_single_conference( conference=conf, years=years, collection_name=collection_name, ) all_results.append(result) # If only one conference, return its result directly if len(all_results) == 1: return json.dumps(all_results[0], indent=2) # Multiple conferences – combine return json.dumps({"conference_results": all_results}, indent=2) except Exception as e: logger.error(f"Failed to get conference topics: {str(e)}") return json.dumps({"error": str(e)}, indent=2)
[docs] def merge_where_clause_with_conference( where: Optional[Dict[str, Any]], conference: Optional[str], ) -> Optional[Dict[str, Any]]: """ Merge a WHERE clause with a conference filter. This helper function properly combines custom WHERE clauses with conference filters, avoiding duplicates and handling nested operators correctly. Parameters ---------- where : dict, optional Custom WHERE clause from user conference : str, optional Conference name to filter by Returns ------- dict or None Merged WHERE clause, or None if both inputs are None Raises ------ ValueError If WHERE clause is not a dict """ # Validate where parameter if where is not None and not isinstance(where, dict): raise ValueError(f"WHERE clause must be a dict, got {type(where).__name__}") # If no conference, just return a deep copy of WHERE clause (or None) if not conference: return deepcopy(where) if where else None # If no WHERE clause, just return conference filter if not where: return {"conference": conference} # Check if conference already exists anywhere in WHERE clause def has_conference_filter(obj: Any) -> bool: """Recursively check if conference filter exists in nested structure.""" if isinstance(obj, dict): if "conference" in obj: return True # Check nested values for value in obj.values(): if has_conference_filter(value): return True elif isinstance(obj, list): for item in obj: if has_conference_filter(item): return True return False # If conference already in WHERE clause, don't add again - return deep copy if has_conference_filter(where): return deepcopy(where) # Need to merge conference with WHERE clause - use deep copy to prevent mutations where_filter = deepcopy(where) # If WHERE already has $and, append to it if "$and" in where_filter: where_filter["$and"].append({"conference": conference}) else: # Create new $and with existing filter and conference where_filter = {"$and": [where_filter, {"conference": conference}]} return where_filter
[docs] def merge_where_clause_with_years( where: Optional[Dict[str, Any]], years: Optional[List[int]], ) -> Optional[Dict[str, Any]]: """ Merge a WHERE clause with a years filter. This helper function properly combines custom WHERE clauses with a years filter, avoiding duplicates and handling nested operators correctly. Parameters ---------- where : dict, optional Custom WHERE clause from user years : list of int, optional List of years to filter by Returns ------- dict or None Merged WHERE clause, or None if both inputs are None Raises ------ ValueError If WHERE clause is not a dict """ # Validate where parameter if where is not None and not isinstance(where, dict): raise ValueError(f"WHERE clause must be a dict, got {type(where).__name__}") # If no years, just return a deep copy of WHERE clause (or None) if not years: return deepcopy(where) if where else None # convert years to string because ChromaDB metadata is stored as strings years_str: List[str] = [str(y) for y in years] # If no WHERE clause, just return years filter if not where: return {"year": {"$in": years_str}} # Check if year filter already exists anywhere in WHERE clause def has_year_filter(obj: Any) -> bool: """Recursively check if year filter exists in nested structure.""" if isinstance(obj, dict): if "year" in obj: return True # Check nested values for value in obj.values(): if has_year_filter(value): return True elif isinstance(obj, list): for item in obj: if has_year_filter(item): return True return False # If year filter already in WHERE clause, don't add again - return deep copy if has_year_filter(where): return deepcopy(where) # Need to merge years with WHERE clause - use deep copy to prevent mutations where_filter = deepcopy(where) year_filter = {"year": {"$in": years_str}} # If WHERE already has $and, append to it if "$and" in where_filter: where_filter["$and"].append(year_filter) else: # Create new $and with existing filter and year filter where_filter = {"$and": [where_filter, year_filter]} return where_filter
[docs] @mcp.tool() def get_topic_evolution( topic_keywords: str, conferences: Optional[list[str]] = None, start_year: Optional[int] = None, end_year: Optional[int] = None, distance_threshold: float = 1.1, collection_name: Optional[str] = None, ) -> str: """ Analyze how topics have evolved over the years for one or more conferences. For each conference and year in the given range, this tool uses ``EmbeddingsManager.find_papers_within_distance()`` to count how many papers are semantically close to the topic keywords. It also computes the relative percentage of matching papers with respect to the total number of papers for that conference and year. At least one conference must be specified. The chat frontend can use the returned data to generate a plot with plotly.js showing the topic evolution over time. Parameters ---------- topic_keywords : str Keywords describing the topic to analyze (e.g., "transformers attention") conferences : list of str, optional Conference names to analyze (e.g., ["NeurIPS", "ICLR"]). Required – returns an error when not provided. start_year : int, optional Start year for analysis (inclusive) end_year : int, optional End year for analysis (inclusive) distance_threshold : float, optional Maximum Euclidean distance in embedding space to consider papers relevant (default: 1.1). Lower values mean stricter matching. collection_name : str, optional Name of ChromaDB collection Returns ------- str JSON string containing topic evolution analysis with per-conference year_counts, year_relative (percentage), and year_totals. """ try: if not conferences: return json.dumps( { "error": ( "A conference must be specified for topic evolution analysis. " "Please provide conferences parameter." ) }, indent=2, ) config = get_config() collection_name = collection_name or config.collection_name # Initialize embeddings manager em = EmbeddingsManager( collection_name=collection_name, ) em.connect() em.create_collection() # Initialize database db = DatabaseManager() db.connect() logger.info(f"Analyzing topic evolution for: {topic_keywords}") logger.info(f"Conferences: {conferences}") logger.info(f"Distance threshold: {distance_threshold}") # Embed the query once here and reuse for every (conference, year) pair # to avoid redundant LLM API calls. query_embedding = em.generate_embedding(topic_keywords) conference_data: Dict[str, Dict[str, Any]] = {} total_papers = 0 all_years: set[int] = set() for conference in conferences: # Determine year range from database for this conference available_years = db.get_years_for_conference(conference) if start_year is not None: available_years = [y for y in available_years if y >= start_year] if end_year is not None: available_years = [y for y in available_years if y <= end_year] logger.info(f"Conference: {conference}, years: {available_years}") year_counts: Dict[int, int] = {} year_relative: Dict[int, float] = {} year_totals: Dict[int, int] = {} year_distribution: Dict[int, list] = {} for year in available_years: result_data = em.find_papers_within_distance( database=db, query=topic_keywords, distance_threshold=distance_threshold, conferences=[conference], years=[year], query_embedding=query_embedding, ) count = result_data["count"] year_counts[year] = count total_papers += count # Get total papers for this conference+year for relative percentage stats = db.get_stats(year=year, conference=conference) total_for_year = stats["total_papers"] year_totals[year] = total_for_year if total_for_year > 0: year_relative[year] = round((count / total_for_year) * 100, 2) else: year_relative[year] = 0.0 # Collect sample papers (top 3 closest) sample_papers = [] for paper in result_data["papers"][:3]: sample_papers.append( { "title": paper.get("title", ""), "session": paper.get("session", ""), "distance": paper.get("distance"), } ) year_distribution[year] = sample_papers all_years.update(year_counts.keys()) # Sort by year sorted_years = sorted(year_counts.keys()) conference_data[conference] = { "year_counts": dict(sorted(year_counts.items())), "year_relative": dict(sorted(year_relative.items())), "year_totals": dict(sorted(year_totals.items())), "papers_by_year": { year: { "count": year_counts[year], "relative_percent": year_relative[year], "total_for_year": year_totals[year], "sample_papers": year_distribution[year], } for year in sorted_years }, } sorted_all_years = sorted(all_years) # Build result result: Dict[str, Any] = { "topic": topic_keywords, "conferences": conferences, "distance_threshold": distance_threshold, "total_papers": total_papers, "year_range": { "start": min(sorted_all_years) if sorted_all_years else None, "end": max(sorted_all_years) if sorted_all_years else None, }, "conference_data": conference_data, } # Clean up em.close() db.close() return json.dumps(result, indent=2) except Exception as e: logger.error(f"Failed to get topic evolution: {str(e)}") return json.dumps({"error": str(e)}, indent=2)
[docs] @mcp.tool() def search_papers( topic_keywords: str, years: Optional[List[int]] = None, n_results: int = 10, conference: Optional[str] = None, where: Optional[Dict[str, Any]] = None, collection_name: Optional[str] = None, ) -> str: """ Search for papers on a specific topic. This tool searches for the most relevant papers about a topic, optionally filtered by specific years. A conference must be specified. Parameters ---------- topic_keywords : str Keywords describing the topic (e.g., "large language models") years : list of int, optional List of specific years to filter by (e.g., [2024, 2025]). If None, searches all years. n_results : int, optional Number of papers to return (default: 10) conference : str, optional Conference name to filter by (e.g., "NeurIPS", "ICLR"). Required – returns an error when not provided. where : dict, optional Custom ChromaDB WHERE clause for filtering results by metadata. Supports ChromaDB query operators like $eq, $ne, $gt, $gte, $lt, $lte, $in, $nin. Logical operators $and, $or are also supported. Examples: ``{"year": 2025}``, ``{"session": {"$in": ["Oral Session 1", "Oral Session 2"]}}``, ``{"$and": [{"year": {"$gte": 2024}}, {"conference": "NeurIPS"}]}``. Note: If 'conference' parameter is provided, it will be merged with this WHERE clause. collection_name : str, optional Name of ChromaDB collection Returns ------- str JSON string containing search results """ try: if not conference: return json.dumps( { "error": ( "A conference must be specified for paper search. " "Please provide conference parameter." ) }, indent=2, ) config = get_config() collection_name = collection_name or config.collection_name # Initialize embeddings manager em = EmbeddingsManager( collection_name=collection_name, ) em.connect() em.create_collection() # Initialize database db = DatabaseManager() db.connect() # Build metadata filter using helper function try: where_filter = merge_where_clause_with_conference(where, conference) where_filter = merge_where_clause_with_years(where_filter, years) except ValueError as e: logger.error(f"Invalid WHERE clause: {str(e)}") return json.dumps({"error": f"Invalid WHERE clause: {str(e)}"}, indent=2) # Search for papers search_desc = f"papers from {years}" if years else "papers" logger.info(f"Searching for {search_desc} about: {topic_keywords}") if where_filter: logger.info(f"Applying WHERE filter: {where_filter}") if years: logger.info(f"Year filter: {years}") results = em.search_similar( query=topic_keywords, n_results=n_results, where=where_filter, ) # Filter and format results papers = [] if results["ids"] and results["ids"][0]: for idx, paper_id in enumerate(results["ids"][0]): metadata = results["metadatas"][0][idx] papers.append( { "uid": paper_id, "title": metadata.get("title", ""), "authors": metadata.get("authors", []), "year": metadata.get("year"), "conference": metadata.get("conference", ""), "session": metadata.get("session", ""), "abstract": ( results["documents"][0][idx] if "documents" in results and results["documents"][0] else "" ), "relevance_score": 1.0 - results["distances"][0][idx] if "distances" in results else None, } ) if len(papers) >= n_results: break result = { "topic": topic_keywords, "conference": conference, "years_filter": years, "papers_found": len(papers), "papers": papers, } # Clean up em.close() db.close() return json.dumps(result, indent=2) except Exception as e: logger.error(f"Failed to search papers: {str(e)}") return json.dumps({"error": str(e)}, indent=2)
[docs] @mcp.tool() def get_paper_details( title: Optional[str] = None, paper_id: Optional[str] = None, conference: Optional[str] = None, year: Optional[int] = None, limit: int = 5, ) -> str: """ Get detailed information about papers from the database. Use for folow-up questions after searching for papers using semantic search. Returns full paper metadata including authors, URLs, PDF links, session info, keywords, awards, and other details stored in the database. At least one of *title* or *paper_id* must be provided. Parameters ---------- title : str, optional Title or partial title to search for (case-insensitive). paper_id : str, optional Unique paper identifier (uid or original conference/OpenReview ID). When provided, performs an exact lookup and ignores *title*. conference : str, optional Filter results by conference name (e.g., "NeurIPS", "ICLR"). Only applied when searching by *title*. year : int, optional Filter results by publication year. Only applied when searching by *title*. limit : int, optional Maximum number of papers to return when searching by title (default: 5). Returns ------- str JSON string with fields: - ``papers_found`` – number of papers returned - ``papers`` – list of paper dicts, each containing: title, authors (list), abstract, url, paper_pdf_url, poster_image_url, session, room_name, starttime, endtime, poster_position, keywords, award, year, conference, original_id """ if not title and not paper_id: return json.dumps( {"error": "Provide at least one of 'title' or 'paper_id' to look up a paper."}, indent=2, ) try: db = DatabaseManager() db.connect() result_papers: List[Dict[str, Any]] = [] if paper_id: # Exact lookup by uid or original_id (returns at most one paper) paper = db.get_paper_by_original_id_or_uid(paper_id) if paper is not None: result_papers = [paper] if not result_papers and title: # Keyword search on title with optional conference/year filters result_papers = db.search_papers( keyword=title, conference=conference, year=year, limit=limit, ) result = { "papers_found": len(result_papers), "papers": result_papers, } db.close() return json.dumps(result, indent=2, default=str) except Exception as e: logger.error(f"Failed to get paper details: {str(e)}") return json.dumps({"error": str(e)}, indent=2)
[docs] @mcp.tool() def analyze_topic_relevance( topic: str, distance_threshold: float = 1.1, conferences: Optional[list[str]] = None, years: Optional[list[int]] = None, collection_name: Optional[str] = None, ) -> str: """ Analyze the relevance of a topic by counting papers within a specified distance in embedding space. This tool measures topic relevance by finding papers semantically similar to the topic within a specified Euclidean distance threshold. It's useful for identifying how prevalent or relevant a research topic is at a conference. A conference must be specified. Parameters ---------- topic : str The topic or research question to analyze (e.g., "Uncertainty quantification", "Graph neural networks", "Transformer architectures") distance_threshold : float, optional Maximum Euclidean distance in embedding space to consider papers relevant (default: 1.1). Lower values mean stricter matching. Typical range: 0.5-2.0 for normalized embeddings. conferences : list of str, optional Conference names to filter by (e.g., ["NeurIPS", "ICLR"]). Required – returns an error when not provided. years : list of int, optional Filter results to specific years (e.g., [2024, 2025]) collection_name : str, optional Name of ChromaDB collection (uses config default if not provided) Returns ------- str JSON string containing: - topic: The topic analyzed - distance_threshold: Distance threshold applied - total_papers: Number of papers found within distance - total_considered: Total number of filtered papers considered - conferences: Conferences represented (with counts) - years: Years represented (with counts) - sample_papers: Sample of closest papers with titles and distances - relevance_score: Percentage of filtered papers within distance (0-100) Examples -------- Topic: "Uncertainty quantification" Result: 75 papers found within distance 1.1 Interpretation: High relevance - this is a significant topic at the conference Query: "Quantum machine learning" Result: 3 papers found within distance 1.1 Interpretation: Low relevance - emerging or niche topic """ try: if not conferences: return json.dumps( { "error": ( "A conference must be specified for topic relevance analysis. " "Please provide conferences parameter." ) }, indent=2, ) config = get_config() collection_name = collection_name or config.collection_name # Initialize embeddings manager em = EmbeddingsManager( collection_name=collection_name, ) em.connect() em.create_collection() # Initialize database db = DatabaseManager() db.connect() # Find papers within distance logger.info(f"Analyzing relevance for topic: {topic}") logger.info(f"Distance threshold: {distance_threshold}") if conferences: logger.info(f"Filtering by conferences: {conferences}") if years: logger.info(f"Filtering by years: {years}") result_data = em.find_papers_within_distance( database=db, query=topic, distance_threshold=distance_threshold, conferences=conferences, years=years, ) # Analyze results papers = result_data["papers"] total_papers = len(papers) total_considered = result_data.get("total_considered", total_papers) # Count by conference conference_counts: Counter[str] = Counter() year_counts: Counter[int] = Counter() for paper in papers: if paper.get("conference"): conference_counts[paper["conference"]] += 1 if paper.get("year"): year_counts[paper["year"]] += 1 # Calculate relevance score (0-100 scale) # Ratio of papers within distance threshold to total filtered papers if total_considered > 0: relevance_score = (total_papers / total_considered) * 100 else: relevance_score = 0 # Get sample papers (top 5 closest) sample_papers = [] for paper in papers[:5]: sample_papers.append( { "title": paper.get("title", ""), "year": paper.get("year"), "conference": paper.get("conference", ""), "distance": paper.get("distance"), } ) # Build result result = { "topic": topic, "distance_threshold": distance_threshold, "filters": { "conferences": conferences, "years": years, }, "total_papers": total_papers, "total_considered": total_considered, "relevance_score": round(relevance_score, 1), "conferences": dict(sorted(conference_counts.items(), key=lambda x: (-x[1], x[0]))), "years": dict(sorted(year_counts.items(), key=lambda x: x[0])), "sample_papers": sample_papers, "closest_distance": papers[0].get("distance") if papers else None, } # Clean up em.close() db.close() return json.dumps(result, indent=2) except Exception as e: logger.error(f"Failed to analyze topic relevance: {str(e)}") return json.dumps({"error": str(e)}, indent=2)
[docs] @mcp.tool() def get_cluster_visualization( conferences: Optional[List[str]] = None, years: Optional[List[int]] = None, output_path: Optional[str] = None, collection_name: Optional[str] = None, **kwargs, ) -> str: """ Retrieve pre-computed visualization data for clustered embeddings. This tool looks up cached clustering results (pre-generated via CLI) and returns data suitable for visualization. A conference must be specified. When multiple conferences are provided, each conference is looked up individually and results are combined. The chat frontend can use the returned data to generate a plot with plotly.js showing the clusters. Parameters ---------- conferences : list of str, optional Conference names to retrieve clusters for (e.g. ["NeurIPS"]). Required – returns an error when not provided. years : list of int, optional Filter by publication years. output_path : str, optional Path to save visualization JSON file (optional). collection_name : str, optional Name of ChromaDB collection. **kwargs Ignored (for backwards compatibility with old tool schemas). Returns ------- str JSON string containing visualization data with points, clusters, and statistics. """ try: if not conferences: return json.dumps( { "error": ( "A conference must be specified for cluster visualization. " "Please provide conferences parameter." ) }, indent=2, ) config = get_config() collection_name = collection_name or config.collection_name all_points: List[Dict[str, Any]] = [] combined_stats: Dict[str, Any] = {} for conf in conferences: # Parse year from conference name if present (e.g. "NeurIPS 2025") parsed_conf, extracted_year = _parse_conference_year(conf) vis_years = list(years) if years else None if extracted_year is not None: if vis_years is None: vis_years = [extracted_year] elif extracted_year not in vis_years: vis_years = sorted(vis_years + [extracted_year]) cm, db = load_clustering_data(collection_name) try: cached = _lookup_clustering_cache(db, config, parsed_conf, vis_years) finally: cm.embeddings_manager.close() db.close() if not cached: return json.dumps( { "error": ( f"No pre-computed clustering data available for conference " f"'{parsed_conf}'" + (f" years={vis_years}" if vis_years else "") + ". Run 'abstracts-explorer clustering pre-generate' first." ), }, indent=2, ) all_points.extend(cached.get("points", [])) if not combined_stats: combined_stats = cached.get("statistics", {}) else: # Merge stats across conferences combined_stats["n_clusters"] = combined_stats.get("n_clusters", 0) + cached.get("statistics", {}).get( "n_clusters", 0 ) combined_stats["total_papers"] = combined_stats.get("total_papers", 0) + cached.get( "statistics", {} ).get("total_papers", 0) # Export if requested if output_path: import pathlib try: export_data = {"points": all_points, "statistics": combined_stats} pathlib.Path(output_path).write_text(json.dumps(export_data, indent=2)) except OSError as exc: logger.warning(f"Failed to write visualization to {output_path}: {exc}") output_path = None result = { "n_dimensions": 2, "n_points": len(all_points), "statistics": combined_stats, "points": all_points[:1000], # Limit for MCP response size "visualization_saved": output_path is not None, "output_path": output_path if output_path else None, } return json.dumps(result, indent=2) except Exception as e: logger.error(f"Failed to generate cluster visualization: {str(e)}") return json.dumps({"error": str(e)}, indent=2)
[docs] def run_mcp_server( host: str = "127.0.0.1", port: int = 8000, transport: str = "sse", ) -> None: """ Run the MCP server. Parameters ---------- host : str, optional Host address to bind to (default: "127.0.0.1") port : int, optional Port to listen on (default: 8000) transport : str, optional Transport method: 'sse' or 'stdio' (default: 'sse') Examples -------- >>> run_mcp_server(host="0.0.0.0", port=8000) """ logger.info(f"Starting MCP server on {host}:{port} with {transport} transport") if transport == "stdio": # Run with stdio transport (for local CLI integration) import asyncio asyncio.run(mcp.run_stdio_async()) else: # Run with SSE transport (for HTTP integration) mcp.run(host=host, port=port)
if __name__ == "__main__": # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) # Run server run_mcp_server()