"""
This module provides functionality to load JSON data into a SQL database.
Supports both SQLite and PostgreSQL backends via SQLAlchemy.
"""
import hashlib
import logging
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from sqlalchemy import create_engine, select, delete, func, or_, and_, text
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError, ProgrammingError, IntegrityError
# Import Pydantic models from plugin framework
from abstracts_explorer.plugin import (
LightweightPaper,
serialize_authors_to_string,
serialize_keywords_to_string,
deserialize_authors_from_string,
deserialize_keywords_from_string,
)
# Import SQLAlchemy models
from abstracts_explorer.db_models import (
Base,
Paper,
EmbeddingsMetadata,
ClusteringCache,
HierarchicalLabelCache,
ValidationData,
ChatDonation,
EvalQAPair,
EvalResult,
)
# Module-level logger used by DatabaseManager for its debug/info/warning/error messages.
logger = logging.getLogger(__name__)
[docs]
class DatabaseError(Exception):
    """
    Raised when an operation against the SQL database fails.

    Serves as the base error of this module (``EmbeddingModelConflictError``
    derives from it), so callers can catch it for any database-layer failure.
    """
[docs]
class EmbeddingModelConflictError(DatabaseError):
    """
    Raised when the embedding model in imported data differs from the local database.

    Attributes
    ----------
    local_model : str
        Embedding model currently in the local database.
    remote_model : str
        Embedding model in the data being imported.
    """

    def __init__(self, local_model: str, remote_model: str) -> None:
        """
        Build the error from the two conflicting model names.

        Parameters
        ----------
        local_model : str
            Embedding model currently in the local database.
        remote_model : str
            Embedding model in the data being imported.
        """
        self.local_model = local_model
        self.remote_model = remote_model
        super().__init__(
            f"Embedding model mismatch: local database uses '{local_model}' "
            f"but imported data uses '{remote_model}'. Cannot import data "
            f"created with a different embedding model."
        )

    def __reduce__(self):
        """
        Support pickling (e.g. when the error crosses a process boundary).

        ``BaseException`` pickles itself as ``cls(*self.args)``, but ``self.args``
        only holds the formatted message, which does not match this class's
        two-argument ``__init__``. Unpickling would therefore raise
        ``TypeError``. Rebuild from the model names instead and restore any
        extra instance state (e.g. exception notes) afterwards.
        """
        return (type(self), (self.local_model, self.remote_model), self.__dict__)
[docs]
def normalize_model_name(name: str) -> str:
    """
    Return *name* with one leading ``alias-`` prefix removed, for comparisons.

    The prefix is matched case-insensitively, so both
    ``alias-qwen3-embeddings-8b`` and ``ALIAS-qwen3-embeddings-8b`` normalize
    to ``qwen3-embeddings-8b``. Only a single prefix at the very start is
    stripped; names without it are returned unchanged.

    Parameters
    ----------
    name : str
        Embedding model name.

    Returns
    -------
    str
        The name without its leading ``alias-`` prefix.
    """
    prefix = re.match(r"alias-", name, flags=re.IGNORECASE)
    if prefix is None:
        return name
    return name[prefix.end():]
[docs]
class DatabaseManager:
    """
    Manager for SQL database operations using SQLAlchemy.

    Supports SQLite and PostgreSQL backends through SQLAlchemy connection URLs.
    Database configuration is read from the config file (PAPER_DB variable);
    the URL is taken from ``get_config().database_url`` when the manager is
    constructed. No connection exists until :meth:`connect` is called.

    Attributes
    ----------
    database_url : str
        SQLAlchemy database URL from configuration.
    engine : Engine or None
        SQLAlchemy engine instance; ``None`` while not connected.
    SessionLocal : sessionmaker or None
        SQLAlchemy session factory; ``None`` while not connected.
    _session : Session or None
        Active database session if connected.
    connection : object or None
        Legacy attribute kept for backward compatibility: the driver-level
        connection behind one raw pool connection while connected (so that
        e.g. tests can call ``.connection.cursor()``), ``None`` otherwise.

    Examples
    --------
    >>> # Database configuration comes from config file
    >>> db = DatabaseManager()
    >>> db.connect()
    >>> db.create_tables()
    >>> db.close()
    """
[docs]
def __init__(self):
    """
    Create an unconnected manager.

    The database URL is looked up from the application config; the engine,
    session factory and session stay ``None`` until :meth:`connect` is called.
    """
    from abstracts_explorer.config import get_config

    self.database_url = get_config().database_url
    # The attributes below are populated by connect() and reset by close().
    self.engine: Optional[Engine] = None
    self.SessionLocal: Optional[sessionmaker] = None
    self._session: Optional[Session] = None
    # Legacy attribute kept for backward compatibility; holds the raw driver
    # connection only while connected.
    self.connection = None
[docs]
def connect(self) -> None:
    """
    Connect to the database.

    Builds the SQLAlchemy engine and session factory, opens a session and
    checks out one raw pool connection that backs the legacy ``connection``
    attribute. For file-based SQLite URLs the parent directory of the
    database file is created if it does not exist.

    If any step fails, everything created so far is released again, so a
    failed call does not leave the manager looking connected.

    Raises
    ------
    DatabaseError
        If connection fails.
    """
    # Local import: used only to parse the URL robustly (driver suffixes such
    # as ``sqlite+pysqlite://``, query strings, in-memory databases).
    from sqlalchemy.engine.url import make_url

    try:
        url = make_url(self.database_url)
        is_sqlite = url.get_backend_name() == "sqlite"

        # Create parent directories for file-based SQLite databases. The
        # database part is empty for ``sqlite://`` and ":memory:" for explicit
        # in-memory databases; neither needs a directory.
        if is_sqlite and url.database and url.database != ":memory:":
            Path(url.database).parent.mkdir(parents=True, exist_ok=True)

        # SQLite-specific settings: allow use from threads other than the creator.
        connect_args = {"check_same_thread": False} if is_sqlite else {}

        self.engine = create_engine(
            self.database_url,
            connect_args=connect_args,
            echo=False,  # Set to True for SQL debugging
        )
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine,
        )
        self._session = self.SessionLocal()

        # Legacy attribute: expose a raw driver connection for backward
        # compatibility (tests use ``.connection.cursor()``).
        self._raw_connection = self.engine.raw_connection()
        self.connection = self._raw_connection.driver_connection

        logger.debug(f"Connected to database: {self._mask_url(self.database_url)}")
    except Exception as e:
        # Without this cleanup a failure after the engine/session were created
        # (e.g. the server is unreachable when the raw connection is checked
        # out) leaves ``_session`` set, so other methods would treat the
        # manager as connected, and the engine would never be disposed.
        try:
            self.close()
        except Exception:  # best effort; report the original failure below
            logger.debug("Cleanup after failed connect raised", exc_info=True)
        raise DatabaseError(f"Failed to connect to database: {str(e)}") from e
def _mask_url(self, url: str) -> str:
"""Mask password in URL for logging."""
if "@" in url and ":" in url:
parts = url.split("@")
if len(parts) == 2:
before_at = parts[0]
if "://" in before_at:
protocol_user = before_at.rsplit(":", 1)[0]
return f"{protocol_user}:***@{parts[1]}"
return url
[docs]
def close(self) -> None:
    """
    Release the session, the legacy raw connection and the engine.

    Safe to call when not connected or more than once. Each resource is
    closed only if it is currently set. Afterwards ``engine``,
    ``SessionLocal``, the session and ``connection`` are all ``None``.
    """
    session = self._session
    if session:
        session.close()
        self._session = None

    raw = getattr(self, "_raw_connection", None)
    if raw:
        raw.close()
        self._raw_connection = None

    engine = self.engine
    if engine:
        engine.dispose()
        self.engine = None

    self.SessionLocal = None
    self.connection = None
    logger.debug("Database connection closed")
[docs]
def __enter__(self):
    """
    Open the database connection for a ``with`` block.

    Returns
    -------
    DatabaseManager
        This manager, already connected.
    """
    self.connect()
    return self
[docs]
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Close the connection when leaving a ``with`` block.

    Returns ``None``, so an exception raised inside the block is not
    suppressed.
    """
    self.close()
    return None
[docs]
@staticmethod
def compute_uid(title: str, original_id: Optional[Union[str, int]], conference: str, year: int) -> str:
"""
Compute a deterministic paper UID from its identifying fields.
The UID is a 16-character hex string derived from a SHA-256 hash
of the concatenated title, original_id, conference and year.
Parameters
----------
title : str
Paper title.
original_id : str, int, or None
Original paper ID from the source (e.g., OpenReview ID).
conference : str
Conference name (e.g., "NeurIPS").
year : int
Conference year.
Returns
-------
str
16-character hex UID.
Examples
--------
>>> DatabaseManager.compute_uid("My Paper", "abc123", "NeurIPS", 2025)
'a1b2c3d4e5f67890'
"""
uid_source = f"{title}:{original_id}:{conference}:{year}"
return hashlib.sha256(uid_source.encode("utf-8")).hexdigest()[:16]
[docs]
def create_tables(self) -> None:
    """
    Create every table registered on the ORM ``Base`` metadata.

    This includes, among others, ``papers`` (paper information with the
    lightweight ML4PS schema), ``embeddings_metadata`` (model used and
    creation date) and ``clustering_cache``.

    The call is idempotent. Tables that already exist are skipped, and an
    "already exists" error raised by a concurrent creator (e.g. PostgreSQL's
    ``pg_type_typname_nsp_index`` violation or SQLite's "table already
    exists") is treated as success.

    Raises
    ------
    DatabaseError
        If not connected, or if table creation fails for any other reason.
    """
    if not self.engine:
        raise DatabaseError("Not connected to database")

    benign_markers = ("already exists", "duplicate", "pg_type_typname_nsp_index")
    try:
        # checkfirst=True only issues CREATE for tables that are missing.
        Base.metadata.create_all(bind=self.engine, checkfirst=True)
        logger.debug("Database tables created successfully")
    except (OperationalError, ProgrammingError, IntegrityError) as e:
        # Another process may create a table between the existence check and
        # our CREATE; that race is harmless.
        message = str(e).lower()
        if not any(marker in message for marker in benign_markers):
            raise DatabaseError(f"Failed to create tables: {str(e)}") from e
        logger.debug(f"Tables already exist (this is normal): {str(e)}")
    except Exception as e:
        raise DatabaseError(f"Failed to create tables: {str(e)}") from e
[docs]
def add_paper(self, paper: LightweightPaper) -> Optional[str]:
    """
    Insert one paper unless a paper with the same UID is already stored.

    The UID comes from :meth:`compute_uid` over the paper's title,
    original_id, conference and year. Authors and keywords are serialized to
    strings with the plugin helpers, and the insert is committed immediately.

    Parameters
    ----------
    paper : LightweightPaper
        Validated paper object to insert.

    Returns
    -------
    str or None
        UID of the newly inserted paper, or ``None`` if the paper was skipped
        because its UID already exists.

    Raises
    ------
    DatabaseError
        If not connected, or if serialization, the duplicate lookup or the
        insert fails (the session is rolled back in that case).

    Examples
    --------
    >>> from abstracts_explorer.plugin import LightweightPaper
    >>> db = DatabaseManager()
    >>> with db:
    ...     db.create_tables()
    ...     paper = LightweightPaper(
    ...         title="Test Paper",
    ...         authors=["John Doe"],
    ...         abstract="Test abstract",
    ...         session="Session 1",
    ...         poster_position="P1",
    ...         year=2025,
    ...         conference="NeurIPS"
    ...     )
    ...     paper_uid = db.add_paper(paper)
    >>> print(f"Inserted paper with UID: {paper_uid}")
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    db = self._session
    try:
        # Any falsy original_id (None, "", ...) is treated as missing. This is
        # the value the UID hash sees, so it must stay stable for existing rows.
        source_id = paper.original_id or None

        # Authors are stored as a single serialized string.
        authors_str = serialize_authors_to_string(paper.authors)

        # UID = hash of title + original_id + conference + year.
        uid = self.compute_uid(paper.title, source_id, paper.conference, paper.year)

        already_stored = db.execute(select(Paper).where(Paper.uid == uid)).scalar_one_or_none()
        if already_stored:
            logger.debug(f"Skipping duplicate paper: {paper.title} (uid: {uid})")
            return None

        keywords_str = serialize_keywords_to_string(paper.keywords) if paper.keywords else ""
        record = Paper(
            uid=uid,
            original_id=str(source_id) if source_id is not None else None,
            title=paper.title,
            authors=authors_str,
            abstract=paper.abstract,
            session=paper.session,
            poster_position=paper.poster_position,
            paper_pdf_url=paper.paper_pdf_url,
            poster_image_url=paper.poster_image_url,
            url=paper.url,
            room_name=paper.room_name,
            keywords=keywords_str,
            starttime=paper.starttime,
            endtime=paper.endtime,
            award=paper.award,
            year=paper.year,
            conference=paper.conference,
        )
        db.add(record)
        db.commit()
        return uid
    except Exception as e:
        db.rollback()
        raise DatabaseError(f"Failed to add paper: {str(e)}") from e
[docs]
def add_papers(self, papers: List[LightweightPaper]) -> int:
    """
    Add multiple papers, one :meth:`add_paper` call (and commit) per paper.

    Duplicates are skipped silently. A paper whose insertion fails with a
    :class:`DatabaseError` is logged as a warning and skipped, and processing
    continues with the next paper.

    Parameters
    ----------
    papers : list of LightweightPaper
        List of validated paper objects to insert.

    Returns
    -------
    int
        Number of papers successfully inserted (excludes duplicates and
        failed papers).

    Raises
    ------
    DatabaseError
        If not connected to the database. Per-paper failures are not raised.

    Examples
    --------
    >>> from abstracts_explorer.plugin import LightweightPaper
    >>> db = DatabaseManager()
    >>> with db:
    ...     db.create_tables()
    ...     papers = [
    ...         LightweightPaper(
    ...             title="Paper 1",
    ...             authors=["Author 1"],
    ...             abstract="Abstract 1",
    ...             session="Session 1",
    ...             poster_position="P1",
    ...             year=2025,
    ...             conference="NeurIPS"
    ...         ),
    ...         LightweightPaper(
    ...             title="Paper 2",
    ...             authors=["Author 2"],
    ...             abstract="Abstract 2",
    ...             session="Session 2",
    ...             poster_position="P2",
    ...             year=2025,
    ...             conference="NeurIPS"
    ...         )
    ...     ]
    ...     count = db.add_papers(papers)
    >>> print(f"Inserted {count} papers")
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    inserted = 0
    failures = 0
    for paper in papers:
        try:
            uid = self.add_paper(paper)
        except DatabaseError as exc:
            failures += 1
            logger.warning(f"Skipping paper '{paper.title}': {exc}")
            continue
        if uid is not None:
            inserted += 1

    if failures:
        logger.warning(f"Skipped {failures} paper(s) due to import errors")
    logger.debug(f"Successfully inserted {inserted} of {len(papers)} papers")
    return inserted
[docs]
def donate_validation_data(self, paper_priorities: Dict[str, Dict[str, Any]]) -> int:
    """
    Store donated paper rating data for validation purposes.

    This method accepts anonymized paper ratings from users and stores them
    in the validation_data table for improving the service. The donation is
    all-or-nothing: every entry is validated before anything is added to the
    session, and all entries are committed together.

    Parameters
    ----------
    paper_priorities : Dict[str, Dict[str, Any]]
        Dictionary mapping paper UIDs to priority data. Each value must be a
        dict and may contain:

        - priority (int): Rating value (stored as 0 if missing)
        - searchTerm (str, optional): Search term associated with the rating

    Returns
    -------
    int
        Number of papers successfully donated

    Raises
    ------
    ValueError
        If paper_priorities is empty or any value is not a dict; nothing is
        stored in that case.
    DatabaseError
        If not connected or the database operation fails (the session is
        rolled back).

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     priorities = {
    ...         "abc123": {"priority": 5, "searchTerm": "machine learning"},
    ...         "def456": {"priority": 4, "searchTerm": "deep learning"}
    ...     }
    ...     count = db.donate_validation_data(priorities)
    ...     print(f"Donated {count} papers")
    """
    if not paper_priorities:
        raise ValueError("No data provided")
    if self._session is None:
        raise DatabaseError("Not connected to database")

    try:
        # Build (and thereby validate) every entry before touching the session.
        # Adding entries one by one and raising mid-loop would leave the earlier
        # entries pending in the session, where the next unrelated commit would
        # persist a partially rejected donation.
        entries = []
        for paper_uid, priority_data in paper_priorities.items():
            if not isinstance(priority_data, dict):
                raise ValueError("Invalid data format. Expected dict with priority and searchTerm")
            entries.append(
                ValidationData(
                    paper_uid=paper_uid,
                    priority=priority_data.get("priority", 0),
                    search_term=priority_data.get("searchTerm", None),
                )
            )
        self._session.add_all(entries)
        self._session.commit()
    except ValueError:
        # Input errors are raised before anything was added to the session.
        raise
    except Exception as e:
        self._session.rollback()
        logger.error(f"Error donating validation data: {e}", exc_info=True)
        raise DatabaseError(f"Failed to donate validation data: {e}") from e

    donated_count = len(entries)
    logger.info(f"Successfully donated {donated_count} papers to validation data")
    return donated_count
[docs]
def donate_chat_transcript(self, rating: str, transcript: List[Dict[str, str]]) -> int:
    """
    Store a donated chat transcript with thumbs up/down feedback.

    This method accepts an anonymized chat transcript and a rating from
    the user and stores it (JSON-encoded) for improving the chat system.

    Parameters
    ----------
    rating : str
        User feedback rating, must be 'up' or 'down'.
    transcript : List[Dict[str, str]]
        List of message dicts, each with 'role' and 'text' keys.

    Returns
    -------
    int
        ID of the stored donation entry.

    Raises
    ------
    ValueError
        If rating is invalid or transcript is empty/malformed.
    DatabaseError
        If not connected or the database operation fails (the session is
        rolled back).

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     transcript = [
    ...         {"role": "user", "text": "What papers discuss transformers?"},
    ...         {"role": "assistant", "text": "Here are some relevant papers..."}
    ...     ]
    ...     donation_id = db.donate_chat_transcript("up", transcript)
    """
    if rating not in ("up", "down"):
        raise ValueError(f"Invalid rating: {rating}. Must be 'up' or 'down'.")
    if not transcript or not isinstance(transcript, list):
        raise ValueError("Transcript must be a non-empty list of messages.")
    for msg in transcript:
        if not isinstance(msg, dict) or "role" not in msg or "text" not in msg:
            raise ValueError("Each message must be a dict with 'role' and 'text' keys.")
    if self._session is None:
        raise DatabaseError("Not connected to database")

    import json

    try:
        donation = ChatDonation(
            rating=rating,
            transcript=json.dumps(transcript),
        )
        self._session.add(donation)
        # Flush first so the primary key is assigned while the object is still
        # loaded. With the default expire_on_commit=True, reading
        # ``donation.id`` after commit would issue an extra refresh query, and
        # a failure there would report an already-committed donation as failed.
        self._session.flush()
        donation_id = donation.id
        self._session.commit()
    except ValueError:
        # json.dumps rejects e.g. circular structures before anything is stored.
        raise
    except Exception as e:
        self._session.rollback()
        logger.error(f"Error donating chat transcript: {e}", exc_info=True)
        raise DatabaseError(f"Failed to donate chat transcript: {e}") from e

    logger.info(f"Successfully donated chat transcript (rating={rating})")
    return donation_id
[docs]
def get_chat_donations(
    self,
    limit: Optional[int] = None,
    rating: Optional[str] = None,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """
    Return donated chat transcripts, newest first.

    Parameters
    ----------
    limit : int, optional
        Maximum number of entries to return. Returns all if None.
    rating : str, optional
        Filter by rating ('up' or 'down'). Returns all ratings if None.
    offset : int, optional
        Number of entries to skip for pagination (default: 0).

    Returns
    -------
    list of dict
        One dict per donation with keys:
        - id (int): Entry ID.
        - rating (str): 'up' or 'down'.
        - transcript (list): Decoded message list; an undecodable stored
          transcript is returned as an empty list.
        - donated_at (datetime): Donation timestamp.

    Raises
    ------
    DatabaseError
        If not connected or the database operation fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     donations = db.get_chat_donations(limit=10, rating='up')
    """
    if self._session is None:
        raise DatabaseError("Not connected to database")
    try:
        import json

        def _decode(raw):
            # Best effort: a corrupt transcript must not break the listing.
            try:
                return json.loads(raw)
            except Exception:
                return []

        stmt = self._session.query(ChatDonation)
        if rating is not None:
            stmt = stmt.filter(ChatDonation.rating == rating)
        stmt = stmt.order_by(ChatDonation.donated_at.desc())
        if offset:
            stmt = stmt.offset(offset)
        if limit is not None:
            stmt = stmt.limit(limit)

        donations = []
        for donation in stmt.all():
            decoded = _decode(donation.transcript)
            donations.append(
                {
                    "id": donation.id,
                    "rating": donation.rating,
                    "transcript": decoded,
                    "donated_at": donation.donated_at,
                }
            )
        return donations
    except Exception as e:
        logger.error(f"Error retrieving chat donations: {e}", exc_info=True)
        raise DatabaseError(f"Failed to retrieve chat donations: {e}")
[docs]
def get_chat_donation_stats(self) -> Dict[str, Any]:
    """
    Summarize the donated chat transcripts.

    Returns
    -------
    dict
        Statistics dict containing:
        - total (int): Total number of donations.
        - up (int): Number of thumbs-up donations.
        - down (int): Number of thumbs-down donations.
        - avg_turns (float): Mean number of messages per transcript; an
          undecodable transcript counts as 0 turns, and the value is 0.0
          when there are no donations.

    Raises
    ------
    DatabaseError
        If not connected or the database operation fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     stats = db.get_chat_donation_stats()
    ...     print(f"Total donations: {stats['total']}")
    """
    if self._session is None:
        raise DatabaseError("Not connected to database")
    try:
        import json

        donations = self._session.query(ChatDonation).all()
        total = len(donations)
        per_rating = {"up": 0, "down": 0}
        turns_total = 0
        for donation in donations:
            if donation.rating in per_rating:
                per_rating[donation.rating] += 1
            try:
                turns = len(json.loads(donation.transcript))
            except Exception:
                turns = 0  # unreadable transcript counts as zero turns
            turns_total += turns

        avg_turns = turns_total / total if total > 0 else 0.0
        return {"total": total, "up": per_rating["up"], "down": per_rating["down"], "avg_turns": avg_turns}
    except Exception as e:
        logger.error(f"Error retrieving chat donation stats: {e}", exc_info=True)
        raise DatabaseError(f"Failed to retrieve chat donation stats: {e}")
[docs]
def get_validation_data(
    self,
    limit: Optional[int] = None,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """
    Return donated validation (interesting-paper) entries, newest first.

    Parameters
    ----------
    limit : int, optional
        Maximum number of entries to return. Returns all if None.
    offset : int, optional
        Number of entries to skip for pagination (default: 0).

    Returns
    -------
    list of dict
        One dict per entry with keys:
        - id (int): Entry ID.
        - paper_uid (str): Paper UID.
        - priority (int): Priority rating.
        - search_term (str or None): Associated search term.
        - donated_at (datetime): Donation timestamp.

    Raises
    ------
    DatabaseError
        If not connected or the database operation fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     data = db.get_validation_data(limit=20)
    """
    if self._session is None:
        raise DatabaseError("Not connected to database")
    try:
        newest_first = self._session.query(ValidationData).order_by(ValidationData.donated_at.desc())
        if offset:
            newest_first = newest_first.offset(offset)
        if limit is not None:
            newest_first = newest_first.limit(limit)

        rows = newest_first.all()
        return [
            {
                "id": row.id,
                "paper_uid": row.paper_uid,
                "priority": row.priority,
                "search_term": row.search_term,
                "donated_at": row.donated_at,
            }
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error retrieving validation data: {e}", exc_info=True)
        raise DatabaseError(f"Failed to retrieve validation data: {e}")
[docs]
def get_validation_data_stats(self) -> Dict[str, Any]:
    """
    Summarize the donated validation (interesting-paper) data.

    Returns
    -------
    dict
        Statistics dict containing:
        - total (int): Total number of donated paper ratings.
        - unique_papers (int): Number of distinct paper UIDs.
        - avg_priority (float): Mean priority (0.0 when there is no data).
        - priority_distribution (dict): Count per priority value, in order
          of first appearance.

    Raises
    ------
    DatabaseError
        If not connected or the database operation fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     stats = db.get_validation_data_stats()
    ...     print(f"Total data donations: {stats['total']}")
    """
    if self._session is None:
        raise DatabaseError("Not connected to database")
    try:
        rows = self._session.query(ValidationData).all()
        total = len(rows)
        distinct_uids = {row.paper_uid for row in rows}
        priorities = [row.priority for row in rows]
        mean_priority = (sum(priorities) / total) if total > 0 else 0.0

        counts: Dict[int, int] = {}
        for value in priorities:
            counts[value] = counts.get(value, 0) + 1

        return {
            "total": total,
            "unique_papers": len(distinct_uids),
            "avg_priority": mean_priority,
            "priority_distribution": counts,
        }
    except Exception as e:
        logger.error(f"Error retrieving validation data stats: {e}", exc_info=True)
        raise DatabaseError(f"Failed to retrieve validation data stats: {e}")
[docs]
def delete_chat_donations(self, ids: Optional[List[int]] = None) -> int:
    """
    Delete donated chat transcripts.

    Parameters
    ----------
    ids : list of int, optional
        List of donation IDs to delete. If None, deletes all donations.
        An empty list deletes nothing.

    Returns
    -------
    int
        Number of donations deleted, as reported by the database for the
        DELETE statement.

    Raises
    ------
    DatabaseError
        If not connected or the database operation fails (the session is
        rolled back).

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     deleted = db.delete_chat_donations()
    ...     print(f"Deleted {deleted} donations")
    """
    if self._session is None:
        raise DatabaseError("Not connected to database")
    try:
        query = self._session.query(ChatDonation)
        if ids is not None:
            query = query.filter(ChatDonation.id.in_(ids))
        # Query.delete() returns the number of rows the DELETE matched. This
        # avoids a separate COUNT query, which could also disagree with the
        # rows actually deleted if the table changes in between.
        deleted = query.delete(synchronize_session=False)
        self._session.commit()
    except Exception as e:
        self._session.rollback()
        logger.error(f"Error deleting chat donations: {e}", exc_info=True)
        raise DatabaseError(f"Failed to delete chat donations: {e}") from e

    logger.info(f"Deleted {deleted} chat donation(s)")
    return deleted
[docs]
def delete_validation_data(self, ids: Optional[List[int]] = None) -> int:
    """
    Delete donated validation (interesting-paper) data.

    Parameters
    ----------
    ids : list of int, optional
        List of entry IDs to delete. If None, deletes all validation data.
        An empty list deletes nothing.

    Returns
    -------
    int
        Number of entries deleted, as reported by the database for the
        DELETE statement.

    Raises
    ------
    DatabaseError
        If not connected or the database operation fails (the session is
        rolled back).

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     deleted = db.delete_validation_data()
    ...     print(f"Deleted {deleted} entries")
    """
    if self._session is None:
        raise DatabaseError("Not connected to database")
    try:
        query = self._session.query(ValidationData)
        if ids is not None:
            query = query.filter(ValidationData.id.in_(ids))
        # Use the DELETE's own row count instead of a separate COUNT query,
        # which costs a round trip and could disagree with the rows deleted.
        deleted = query.delete(synchronize_session=False)
        self._session.commit()
    except Exception as e:
        self._session.rollback()
        logger.error(f"Error deleting validation data: {e}", exc_info=True)
        raise DatabaseError(f"Failed to delete validation data: {e}") from e

    logger.info(f"Deleted {deleted} validation data entry(ies)")
    return deleted
[docs]
def query(self, sql: str, parameters: Union[tuple, Dict[str, Any]] = ()) -> List[Dict[str, Any]]:
    """
    Execute a raw SQL statement and return the result rows as dictionaries.

    Note: This method provides backward compatibility with raw SQL queries.
    For new code, prefer using SQLAlchemy ORM methods.

    Two parameter styles are supported:

    * positional ``?`` placeholders with a tuple (or other sequence) of
      values. Each ``?`` outside a quoted string literal / quoted identifier
      is bound to the next value in order; a ``?`` inside quotes is left as
      literal text.
    * SQLAlchemy named parameters (``:name``) with a dict of values, passed
      through unchanged.

    Parameters
    ----------
    sql : str
        SQL statement to execute.
    parameters : tuple or dict, optional
        Positional values for ``?`` placeholders, or a mapping for named
        ``:name`` parameters.

    Returns
    -------
    list of dict
        Result rows as list of dictionaries (column name -> value).

    Raises
    ------
    DatabaseError
        If not connected or execution fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     results = db.query("SELECT * FROM papers WHERE session = ?", ("Poster",))
    ...     same = db.query("SELECT * FROM papers WHERE session = :s", {"s": "Poster"})
    >>> for row in results:
    ...     print(row['title'])
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        if isinstance(parameters, dict):
            # Named-parameter style: SQL and values are used as given.
            statement_sql = sql
            bind_params = dict(parameters)
        else:
            # Rewrite each ``?`` outside quotes to :param0, :param1, ... for
            # SQLAlchemy's text(). A naive str.replace would also rewrite a
            # ``?`` inside a string literal such as '%?%' and break the query.
            pieces: List[str] = []
            open_quote = None  # quote character of the literal we are inside, if any
            index = 0
            for char in sql:
                if open_quote is not None:
                    # A doubled '' inside a literal closes and reopens it, which
                    # keeps the state correct without special handling.
                    if char == open_quote:
                        open_quote = None
                    pieces.append(char)
                elif char in ("'", '"'):
                    open_quote = char
                    pieces.append(char)
                elif char == "?":
                    pieces.append(f":param{index}")
                    index += 1
                else:
                    pieces.append(char)
            statement_sql = "".join(pieces)
            bind_params = {f"param{i}": value for i, value in enumerate(parameters)}

        result = self._session.execute(text(statement_sql), bind_params)
        return [dict(row._mapping) for row in result]
    except Exception as e:
        raise DatabaseError(f"Query failed: {str(e)}") from e
[docs]
def get_paper_count(self) -> int:
    """
    Count the papers stored in the database.

    Returns
    -------
    int
        Number of rows in the papers table (0 if the database reports none).

    Raises
    ------
    DatabaseError
        If not connected or the count query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        total = self._session.execute(select(func.count()).select_from(Paper)).scalar()
    except Exception as e:
        raise DatabaseError(f"Failed to count papers: {str(e)}") from e
    # scalar() can yield None; normalise that to 0.
    return total if total else 0
[docs]
def get_paper_by_uid(self, uid: str) -> Optional[Dict[str, Any]]:
    """
    Look up a single paper by its UID.

    Parameters
    ----------
    uid : str
        UID of the paper to retrieve.

    Returns
    -------
    dict or None
        The paper converted with ``_paper_to_dict``, or None if no paper has
        this UID.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        lookup = select(Paper).where(Paper.uid == uid)
        match = self._session.execute(lookup).scalar_one_or_none()
        if not match:
            return None
        return self._paper_to_dict(match)
    except Exception as e:
        raise DatabaseError(f"Failed to retrieve paper by UID: {str(e)}") from e
[docs]
def get_paper_by_original_id_or_uid(self, paper_id: str) -> Optional[Dict[str, Any]]:
    """
    Retrieve a paper by its UID or, failing that, by its original_id.

    An exact UID match always takes precedence. Only if no paper has this
    UID is the value matched against ``original_id``. Because UIDs also hash
    conference and year, several papers may share an original_id; the one
    with the smallest UID is then returned, so the result is deterministic.
    All formatting (e.g. author deserialization) is performed inside this
    method, so callers always receive a fully formatted paper dictionary.

    Parameters
    ----------
    paper_id : str
        Value to match against the ``uid`` or ``original_id`` column.

    Returns
    -------
    dict or None
        Fully formatted paper data dictionary, or None if not found.

    Raises
    ------
    DatabaseError
        If not connected or the database query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        # A single "uid = x OR original_id = x ... LIMIT 1" query returns
        # whichever row the database yields first, so it cannot honour the
        # "uid first" contract. Query the UID first, then fall back.
        paper = self._session.execute(select(Paper).where(Paper.uid == paper_id)).scalar_one_or_none()
        if paper is None:
            paper = (
                self._session.execute(
                    select(Paper).where(Paper.original_id == paper_id).order_by(Paper.uid).limit(1)
                )
                .scalars()
                .first()
            )
        return self._paper_to_dict(paper) if paper else None
    except Exception as e:
        raise DatabaseError(f"Failed to retrieve paper by id/uid: {str(e)}") from e
#: Names of Paper table columns that users may target with ``field:"value"``
#: filters in search queries. The internal ``uid`` and ``created_at``
#: columns are left out because they are not meaningful search targets.
SEARCHABLE_FIELDS: set = {column.name for column in Paper.__table__.columns} - {"uid", "created_at"}
#: Alternative field names accepted in ``field:"value"`` filters; each alias
#: is mapped to its canonical column name before the filter is applied.
FIELD_ALIASES: Dict[str, str] = {"author": "authors"}
[docs]
def search_papers(
    self,
    keyword: Optional[str] = None,
    field_filters: Optional[Dict[str, str]] = None,
    session: Optional[str] = None,
    sessions: Optional[List[str]] = None,
    year: Optional[int] = None,
    years: Optional[List[int]] = None,
    conference: Optional[str] = None,
    conferences: Optional[List[str]] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """
    Search for papers by various criteria (lightweight schema).

    All supplied criteria are combined with AND. Text matching (``keyword``
    and ``field_filters``) is a case-insensitive substring match. ``%`` and
    ``_`` in the search value are matched literally, not as SQL wildcards.

    Parameters
    ----------
    keyword : str, optional
        Keyword to search in title, abstract, or keywords fields (a match
        in any of the three suffices).
    field_filters : dict of str to str, optional
        Mapping of papers-table column names (or aliases from
        ``FIELD_ALIASES``, e.g. ``"author"``) to search values. Each entry
        adds a case-insensitive substring condition on that column, e.g.
        ``{"authors": "Smith", "award": "Best Paper"}``. Non-text columns
        are cast to text before matching. Names that are not columns of the
        papers table are ignored.
    session : str, optional
        Single session to filter by (deprecated, use sessions instead).
    sessions : list[str], optional
        List of sessions to filter by (matches ANY).
    year : int, optional
        Single year to filter by (deprecated, use years instead).
    years : list[int], optional
        List of years to filter by (matches ANY).
    conference : str, optional
        Single conference to filter by (deprecated, use conferences instead).
    conferences : list[str], optional
        List of conferences to filter by (matches ANY).
    limit : int, default=100
        Maximum number of results to return; ``0`` disables the limit.

    Returns
    -------
    list of dict
        Matching papers as dictionaries.

    Raises
    ------
    DatabaseError
        If not connected or the search fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     papers = db.search_papers(keyword="neural network", limit=10)
    ...     # Search with multiple sessions
    ...     papers = db.search_papers(sessions=["Session 1", "Session 2"])
    ...     # Search with years
    ...     papers = db.search_papers(years=[2024, 2025])
    ...     # Search by field filter
    ...     papers = db.search_papers(field_filters={"authors": "John Smith"})
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    # Local import: cast/String are only needed by the substring helper below.
    from sqlalchemy import String, cast

    def _icontains(column, value):
        """Case-insensitive substring condition that matches *value* literally."""
        # "/" is the LIKE escape character; escape it first, then the wildcards.
        escaped = str(value).replace("/", "//").replace("%", "/%").replace("_", "/_")
        if not isinstance(getattr(column, "type", None), String):
            # PostgreSQL defines ILIKE only for text (an integer column raises
            # "operator does not exist"), so compare such columns as text.
            column = cast(column, String)
        return column.ilike(f"%{escaped}%", escape="/")

    try:
        conditions = []
        if keyword:
            conditions.append(
                or_(
                    _icontains(Paper.title, keyword),
                    _icontains(Paper.abstract, keyword),
                    _icontains(Paper.keywords, keyword),
                )
            )
        if field_filters:
            table_columns = Paper.__table__.columns
            for field_name, value in field_filters.items():
                # Resolve aliases and accept only real table columns. A plain
                # getattr(Paper, name) would also return methods or mapper
                # attributes and fail inside the query builder.
                column = table_columns.get(self.FIELD_ALIASES.get(field_name, field_name))
                if column is None:
                    logger.debug(f"Ignoring unknown search field: {field_name!r}")
                    continue
                conditions.append(_icontains(column, value))

        # Handle sessions (prefer list form, fall back to single)
        session_list = sessions if sessions else ([session] if session else [])
        if session_list:
            conditions.append(Paper.session.in_(session_list))
        # Handle years (prefer list form, fall back to single)
        year_list = years if years else ([year] if year else [])
        if year_list:
            conditions.append(Paper.year.in_(year_list))
        # Handle conferences (prefer list form, fall back to single)
        conference_list = conferences if conferences else ([conference] if conference else [])
        if conference_list:
            conditions.append(Paper.conference.in_(conference_list))

        stmt = select(Paper)
        if conditions:
            stmt = stmt.where(and_(*conditions))
        if limit:
            stmt = stmt.limit(limit)

        results = self._session.execute(stmt).scalars().all()
        return [self._paper_to_dict(paper) for paper in results]
    except Exception as e:
        raise DatabaseError(f"Search failed: {str(e)}") from e
[docs]
@staticmethod
def parse_field_filters(query: str) -> tuple:
    """
    Parse field-specific filters from a search query string.

    Extracts all ``field:"value"`` patterns from the query where *field*
    is a column name of the :class:`~abstracts_explorer.db_models.Paper`
    model (or a recognised alias; see :attr:`FIELD_ALIASES`). Field names
    are matched case-insensitively. Unrecognised field names are left in the
    query text as-is. If the same field appears more than once, the first
    occurrence wins and all occurrences are removed from the query.

    Parameters
    ----------
    query : str
        The raw search query, e.g.
        ``'authors:"John Smith" award:"Best Paper" transformers'``.

    Returns
    -------
    tuple of (dict, str)
        A tuple ``(field_filters, remaining_query)`` where
        *field_filters* maps canonical column names to their search values
        (in the order they appear in the query) and *remaining_query* is the
        query with recognised filters removed and whitespace normalised.

    Examples
    --------
    >>> DatabaseManager.parse_field_filters('authors:"John Smith" transformers')
    ({'authors': 'John Smith'}, 'transformers')
    >>> DatabaseManager.parse_field_filters('author:"John Smith" transformers')
    ({'authors': 'John Smith'}, 'transformers')
    >>> DatabaseManager.parse_field_filters('Author:"John Smith" transformers')
    ({'authors': 'John Smith'}, 'transformers')
    >>> DatabaseManager.parse_field_filters('transformers')
    ({}, 'transformers')
    >>> DatabaseManager.parse_field_filters('award:"Best Paper" authors:"Doe"')
    ({'award': 'Best Paper', 'authors': 'Doe'}, '')
    """
    # Lower-cased lookups so user input matches case-insensitively.
    alias_lower = {k.lower(): v for k, v in DatabaseManager.FIELD_ALIASES.items()}
    fields_lower = {f.lower(): f for f in DatabaseManager.SEARCHABLE_FIELDS}

    field_filters: Dict[str, str] = {}
    kept_parts: List[str] = []
    cursor = 0
    # Scan left to right so filters are recorded in query order (the old
    # reverse scan produced reversed dict order, contradicting the examples).
    for match in re.finditer(r'(\w+):"([^"]+)"', query):
        raw_field = match.group(1).lower()
        # Resolve alias (e.g. "author" -> "authors") if applicable
        field_name = fields_lower.get(alias_lower.get(raw_field, raw_field))
        if field_name is None:
            continue  # unrecognised field: keep the text in the query
        # setdefault: the first occurrence of a repeated field wins.
        field_filters.setdefault(field_name, match.group(2).strip())
        kept_parts.append(query[cursor : match.start()])
        cursor = match.end()
    kept_parts.append(query[cursor:])

    remaining = " ".join("".join(kept_parts).split())  # normalise whitespace
    return field_filters, remaining
[docs]
def search_papers_keyword(
    self,
    query: str,
    limit: int = 10,
    sessions: Optional[List[str]] = None,
    years: Optional[List[int]] = None,
    conferences: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Run a keyword search after extracting ``field:"value"`` filters from the query.

    The query is split by :meth:`parse_field_filters` into column filters
    (e.g. ``authors:"John Smith"`` or ``award:"Best Paper"``) and the
    remaining free text. Both are forwarded to :meth:`search_papers`
    together with the session/year/conference filters, and its result
    list is returned unchanged.

    Parameters
    ----------
    query : str
        Free-text keyword(s), optionally mixed with ``field:"value"``
        filters for any Paper column.
    limit : int, optional
        Maximum number of results, by default 10
    sessions : list of str, optional
        Restrict results to these paper sessions
    years : list of int, optional
        Restrict results to these publication years
    conferences : list of str, optional
        Restrict results to these conference names

    Returns
    -------
    list of dict
        The paper dictionaries returned by :meth:`search_papers`.

    Examples
    --------
    >>> papers = db.search_papers_keyword(
    ...     "neural networks",
    ...     limit=5,
    ...     years=[2024, 2025]
    ... )
    >>> papers = db.search_papers_keyword('authors:"John Smith"')
    >>> papers = db.search_papers_keyword('award:"Best Paper" transformers')
    """
    filters, free_text = self.parse_field_filters(query)
    # Empty text / empty filter dict are passed as None so search_papers
    # does not apply them at all.
    return self.search_papers(
        keyword=free_text or None,
        field_filters=filters or None,
        sessions=sessions,
        years=years,
        conferences=conferences,
        limit=limit,
    )
[docs]
def get_stats(self, year: Optional[int] = None, conference: Optional[str] = None) -> Dict[str, Any]:
    """
    Get database statistics, optionally filtered by year and conference.

    Parameters
    ----------
    year : int, optional
        Filter by specific year
    conference : str, optional
        Filter by specific conference

    Returns
    -------
    dict
        Statistics dictionary with:
        - total_papers: int - Number of papers matching filters
        - year: int or None - Filter year if provided
        - conference: str or None - Filter conference if provided

    Raises
    ------
    DatabaseError
        If not connected or the count query fails.

    Examples
    --------
    >>> stats = db.get_stats()
    >>> print(f"Total papers: {stats['total_papers']}")
    >>> stats_2024 = db.get_stats(year=2024)
    >>> print(f"Papers in 2024: {stats_2024['total_papers']}")
    """
    if year is None and conference is None:
        total_papers = self.get_paper_count()
    else:
        if not self._session:
            raise DatabaseError("Not connected to database")
        # Build the filtered COUNT through the ORM, like the other query
        # methods, instead of raw SQL with "?" placeholders. Placeholder
        # syntax is driver-specific, and raw SQL bypasses the SQLite/PostgreSQL
        # abstraction.
        conditions = []
        if year is not None:
            conditions.append(Paper.year == year)
        if conference is not None:
            conditions.append(Paper.conference == conference)
        try:
            stmt = select(func.count()).select_from(Paper).where(and_(*conditions))
            total_papers = self._session.execute(stmt).scalar() or 0
        except Exception as e:
            raise DatabaseError(f"Failed to get stats: {str(e)}") from e
    return {
        "total_papers": total_papers,
        "year": year,
        "conference": conference,
    }
def _paper_to_dict(self, paper: Paper) -> Dict[str, Any]:
    """
    Convert a Paper ORM object to a JSON-serializable dictionary.

    Authors and keywords are deserialized from their stored string form.
    Timestamp columns (``starttime``, ``endtime``, ``created_at``) are
    converted to strings. A missing value stays ``None`` (JSON ``null``)
    rather than becoming the literal string ``"None"``.

    Parameters
    ----------
    paper : Paper
        ORM instance to convert.

    Returns
    -------
    dict
        Mapping of column names to serializable values.
    """

    def _to_str(value: Any) -> Optional[str]:
        # str(None) would produce "None"; keep absent timestamps as None.
        return None if value is None else str(value)

    return {
        "uid": paper.uid,
        "original_id": paper.original_id,
        "title": paper.title,
        "authors": deserialize_authors_from_string(paper.authors),
        "abstract": paper.abstract,
        "session": paper.session,
        "poster_position": paper.poster_position,
        "paper_pdf_url": paper.paper_pdf_url,
        "poster_image_url": paper.poster_image_url,
        "url": paper.url,
        "room_name": paper.room_name,
        "keywords": deserialize_keywords_from_string(paper.keywords),
        "starttime": _to_str(paper.starttime),
        "endtime": _to_str(paper.endtime),
        "award": paper.award,
        "year": paper.year,
        "conference": paper.conference,
        "created_at": _to_str(paper.created_at),
    }
[docs]
def get_author_count(self) -> int:
    """
    Get the approximate number of unique authors in the database.

    Note: This provides an estimate by counting unique author names
    across all papers. The actual count may vary (e.g. the same person
    spelled differently counts twice).

    Returns
    -------
    int
        Approximate number of unique, non-empty author names.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        stmt = select(Paper.authors).where(and_(Paper.authors.isnot(None), Paper.authors != ""))
        results = self._session.execute(stmt).scalars().all()
        # NOTE(review): names are split on ";" here. Confirm this matches the
        # separator used by serialize_authors_to_string.
        author_names = set()
        for authors_str in results:
            if not authors_str:
                continue
            for author in authors_str.split(";"):
                name = author.strip()
                # Skip blanks produced by trailing or doubled separators,
                # which would otherwise be counted as an "author".
                if name:
                    author_names.add(name)
        return len(author_names)
    except Exception as e:
        raise DatabaseError(f"Failed to count authors: {str(e)}") from e
[docs]
def get_years_for_conference(self, conference: str) -> List[int]:
    """
    List the distinct years for which *conference* has papers.

    Unlike :meth:`get_years`, the result is sorted in ascending order.

    Parameters
    ----------
    conference : str
        Exact conference name to look up.

    Returns
    -------
    list of int
        Distinct non-null years for the conference, ascending.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        belongs_to_conference = Paper.conference == conference
        has_year = Paper.year.isnot(None)
        ascending_years = (
            select(Paper.year)
            .distinct()
            .where(and_(belongs_to_conference, has_year))
            .order_by(Paper.year)
        )
        found = self._session.execute(ascending_years).scalars().all()
        return list(found)
    except Exception as e:
        raise DatabaseError(f"Failed to get years for conference: {str(e)}") from e
[docs]
def get_conference_years_from_db(self) -> Dict[str, List[int]]:
    """
    Map every conference in the database to the years that have papers.

    Conferences come from :meth:`get_conferences` and their years from
    :meth:`get_years`, which lists the most recent year first.
    Conferences with no years are omitted.

    Returns
    -------
    dict[str, list[int]]
        Conference name → years (descending). Empty when the database has
        no papers or there is no active session.

    Raises
    ------
    DatabaseError
        If one of the underlying lookups fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     mapping = db.get_conference_years_from_db()
    >>> print(mapping)
    {'NeurIPS': [2025, 2024], 'ICLR': [2024]}
    """
    if not self._session:
        return {}
    try:
        years_by_conference = {conf: self.get_years(conference=conf) for conf in self.get_conferences()}
        return {conf: conf_years for conf, conf_years in years_by_conference.items() if conf_years}
    except Exception as e:
        raise DatabaseError(f"Failed to get conference years from DB: {str(e)}") from e
[docs]
def resolve_default_conference_year(
    self,
    configured_conference: str,
    configured_year: Optional[int],
) -> tuple[str, Optional[int]]:
    """
    Pick the default conference/year, preferring the configured values when they have data.

    The configured conference is matched case-insensitively against the
    conferences in the database. If it matches, the configured year is
    kept when it has papers; otherwise that conference's most recent year
    is used. If the configured conference does not match (or is empty),
    the conference with the most recent year overall is chosen. On ties,
    the first conference in :meth:`get_conferences` order wins.

    Parameters
    ----------
    configured_conference : str
        Conference name from the application configuration. May be empty
        or may not match any downloaded conference.
    configured_year : int or None
        Year from the application configuration. May be ``None`` or may
        not have data for the matched conference.

    Returns
    -------
    tuple[str, int | None]
        The resolved ``(conference, year)`` pair. The configured values are
        returned unchanged when the database has no usable data.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     conf, year = db.resolve_default_conference_year("NeurIPS", 2024)
    >>> print(conf, year)
    NeurIPS 2024
    """
    available = self.get_conferences()
    if not available:
        # Nothing downloaded yet: the configured values are all we have.
        return configured_conference, configured_year

    matched: Optional[str] = None
    if configured_conference:
        wanted = configured_conference.lower()
        matched = next((name for name in available if name.lower() == wanted), None)

    if matched:
        known_years = self.get_years(conference=matched)
        if configured_year and configured_year in known_years:
            return matched, configured_year
        if known_years:
            # get_years() sorts descending, so index 0 is the most recent year.
            return matched, known_years[0]
        return matched, configured_year

    # The configured conference is unusable: fall back to the conference whose
    # latest year is the most recent overall. max() keeps the first on ties.
    latest_per_conference = []
    for name in available:
        name_years = self.get_years(conference=name)
        if name_years:
            latest_per_conference.append((name, name_years[0]))
    if latest_per_conference:
        best_name, best_year = max(latest_per_conference, key=lambda pair: pair[1])
        if best_name:
            return best_name, best_year
    return configured_conference, configured_year
[docs]
def resolve_conference_name(self, conference: str) -> str:
    """
    Resolve a conference name to the canonical form stored in the database.

    Performs a case-insensitive match against conference names already
    present in the database. If no database match is found, falls back
    to a case-insensitive match against the ``conference_name`` attribute of
    every registered downloader plugin. If neither lookup succeeds, a
    :class:`DatabaseError` is raised.

    This is the single authoritative place where conference-name
    normalization must happen. CLI commands should call this method
    **once** at the entry point of each command and then work with the
    returned canonical name for all subsequent operations.

    Parameters
    ----------
    conference : str
        Conference name as supplied by the caller. May differ in case or
        spelling from the form stored in the database (e.g. ``ml4ps@neurips``
        vs. ``ML4PS@Neurips``).

    Returns
    -------
    str
        The conference name exactly as it appears in the database (first
        match), or else exactly as defined by the first matching plugin.

    Raises
    ------
    DatabaseError
        If not connected to the database, or if neither the database nor
        any plugin has a conference matching *conference*.

    Examples
    --------
    >>> with DatabaseManager() as db:
    ...     db.create_tables()
    ...     canonical = db.resolve_conference_name("ml4ps@neurips")
    ...     # Returns "ML4PS@Neurips" if that form is stored in the DB
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    # 1. Try case-insensitive match against conferences already in the DB.
    #    A failed lookup is not fatal (the plugins are still consulted), but
    #    it is logged instead of being silently ignored.
    try:
        for conf in self.get_conferences():
            if conf.lower() == conference.lower():
                return conf
    except Exception as e:
        logger.warning(f"Database lookup for conference {conference!r} failed: {e}")
    # 2. Fall back to plugin conference names
    try:
        from abstracts_explorer.plugins import get_all_plugins

        for plugin in get_all_plugins():
            plugin_conf = getattr(plugin, "conference_name", None)
            if plugin_conf and plugin_conf.lower() == conference.lower():
                return plugin_conf
    except Exception as e:
        logger.debug(f"Plugin lookup for conference {conference!r} failed: {e}")
    raise DatabaseError(
        f"Failed to resolve conference name: {conference}.\n" f"No match found in database or plugins."
    )
[docs]
def resolve_conference_for_url(self, url_path: str) -> dict:
    """
    Resolve a URL path segment to a conference, checking data availability.

    Combines plugin-based name resolution with a database data check.
    Conference names are compared case-insensitively against the database,
    so a plugin name that differs only in case from the stored name still
    counts as having data. In that case the stored form is returned.

    Returns a result dict describing the outcome:

    - **found with data**: ``{"conference": "<name as stored in DB>", "error": None}``
    - **found without data**: ``{"conference": None, "error": {"message": "...", "available_conferences": [...]}}``
    - **not found**: ``{"conference": None, "error": {"message": "...", "available_conferences": [...]}}``

    Parameters
    ----------
    url_path : str
        URL path segment (e.g. ``"neurips"``, ``"ICLR"``).

    Returns
    -------
    dict
        Dictionary with keys ``"conference"`` (str or None) and
        ``"error"`` (dict or None).

    Examples
    --------
    >>> with DatabaseManager() as db:
    ...     result = db.resolve_conference_for_url("neurips")
    ...     if result["conference"]:
    ...         print(f"Found: {result['conference']}")
    """
    from abstracts_explorer.plugin import get_available_filters, resolve_conference_from_url

    db_conference_years: Dict[str, List[int]] = {}
    try:
        db_conference_years = self.get_conference_years_from_db()
    except Exception as e:
        # Best-effort: an unreadable database is treated as "no data" so the
        # caller still gets a structured error, but the failure is logged.
        logger.warning(f"Could not read conference years from database: {e}")

    # Lower-cased DB conference name -> stored form (first occurrence wins).
    db_by_lower: Dict[str, str] = {}
    for conf in db_conference_years:
        db_by_lower.setdefault(conf.lower(), conf)

    resolved = resolve_conference_from_url(url_path)
    if resolved is None:
        # Also check DB conferences directly (in case they don't have a plugin)
        resolved = db_by_lower.get(url_path.lower())

    if resolved:
        # Prefer an exact match; fall back to a case-insensitive one so that
        # e.g. a plugin's "ML4PS@NeurIPS" finds stored "ML4PS@Neurips".
        if resolved in db_conference_years:
            db_name: Optional[str] = resolved
        else:
            db_name = db_by_lower.get(resolved.lower())
        if db_name is not None:
            return {"conference": db_name, "error": None}
        available_conferences = sorted(db_conference_years.keys())
        return {
            "conference": None,
            "error": {
                "message": f"No data available for conference '{resolved}'. Please download data first.",
                "available_conferences": available_conferences,
            },
        }

    plugin_conferences = get_available_filters().get("conferences", [])
    available_conferences = sorted(set(list(db_conference_years.keys()) + plugin_conferences))
    return {
        "conference": None,
        "error": {
            "message": f"Conference '{url_path}' not found.",
            "available_conferences": available_conferences,
        },
    }
[docs]
def get_sessions(self, conference: Optional[str] = None, year: Optional[int] = None) -> List[str]:
    """
    List the distinct, non-empty session names, optionally scoped.

    Parameters
    ----------
    conference : str, optional
        Only include sessions of papers from this conference.
    year : int, optional
        Only include sessions of papers from this year.

    Returns
    -------
    list of str
        Distinct session names, sorted alphabetically; NULL and empty
        values are excluded.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     sessions = db.get_sessions()
    >>> print(sessions)
    ['Session 1', 'Session 2', ...]
    >>> sessions = db.get_sessions(conference="NeurIPS", year=2025)
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        scope = []
        if conference is not None:
            scope.append(Paper.conference == conference)
        if year is not None:
            scope.append(Paper.year == year)
        query = select(Paper.session).distinct()
        if scope:
            query = query.where(and_(*scope))
        has_session = and_(Paper.session.isnot(None), Paper.session != "")
        query = query.where(has_session).order_by(Paper.session)
        names = self._session.execute(query).scalars().all()
        return list(names)
    except Exception as e:
        raise DatabaseError(f"Failed to get sessions: {str(e)}") from e
[docs]
def get_conferences(self, year: Optional[int] = None) -> List[str]:
    """
    List the distinct, non-empty conference names in the database.

    Parameters
    ----------
    year : int, optional
        Only include conferences that have papers in this year.

    Returns
    -------
    list of str
        Distinct conference names, sorted alphabetically; NULL and empty
        values are excluded.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     conferences = db.get_conferences()
    >>> print(conferences)
    ['ICLR', 'NeurIPS']
    >>> conferences = db.get_conferences(year=2025)
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        has_name = and_(Paper.conference.isnot(None), Paper.conference != "")
        query = select(Paper.conference).distinct().where(has_name)
        if year is not None:
            query = query.where(Paper.year == year)
        names = self._session.execute(query.order_by(Paper.conference)).scalars().all()
        return list(names)
    except Exception as e:
        raise DatabaseError(f"Failed to get conferences: {str(e)}") from e
[docs]
def get_years(self, conference: Optional[str] = None) -> List[int]:
    """
    List the distinct non-null years in the database, most recent first.

    Parameters
    ----------
    conference : str, optional
        Only include years in which this conference has papers.

    Returns
    -------
    list of int
        Distinct years in descending order.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     years = db.get_years()
    >>> print(years)
    [2025, 2024, 2023]
    >>> years = db.get_years(conference="NeurIPS")
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        query = select(Paper.year).distinct().where(Paper.year.isnot(None))
        if conference is not None:
            query = query.where(Paper.conference == conference)
        newest_first = query.order_by(Paper.year.desc())
        found = self._session.execute(newest_first).scalars().all()
        return list(found)
    except Exception as e:
        raise DatabaseError(f"Failed to get years: {str(e)}") from e
[docs]
def get_embedding_model(self) -> Optional[str]:
    """
    Return the embedding model name from the most recently updated metadata row.

    Returns
    -------
    str or None
        The model name, or ``None`` when no metadata row exists.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     model = db.get_embedding_model()
    >>> print(model)
    text-embedding-qwen3-embedding-4b
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        newest_model = (
            select(EmbeddingsMetadata.embedding_model)
            .order_by(EmbeddingsMetadata.updated_at.desc())
            .limit(1)
        )
        return self._session.execute(newest_model).scalar_one_or_none()
    except Exception as e:
        raise DatabaseError(f"Failed to get embedding model: {str(e)}") from e
[docs]
def set_embedding_model(self, model_name: str) -> None:
    """
    Set the embedding model used for embeddings.

    If a metadata record exists, the most recently updated one gets the
    new model name and a fresh UTC ``updated_at``. Otherwise a new record
    is inserted. The change is committed immediately.

    Parameters
    ----------
    model_name : str
        Name of the embedding model.

    Raises
    ------
    DatabaseError
        If not connected or the update fails (the session is rolled back).

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     db.set_embedding_model("text-embedding-qwen3-embedding-4b")
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        # Fetch the latest record directly. Its absence already tells us the
        # table is empty, so a separate COUNT(*) round-trip is unnecessary.
        latest_stmt = select(EmbeddingsMetadata).order_by(EmbeddingsMetadata.updated_at.desc()).limit(1)
        latest = self._session.execute(latest_stmt).scalar_one_or_none()
        if latest is not None:
            latest.embedding_model = model_name
            latest.updated_at = datetime.now(timezone.utc)
        else:
            self._session.add(EmbeddingsMetadata(embedding_model=model_name))
        self._session.commit()
        logger.info(f"Set embedding model to: {model_name}")
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to set embedding model: {str(e)}") from e
[docs]
def get_clustering_cache(
    self,
    embedding_model: str,
    clustering_method: str,
    n_clusters: Optional[int] = None,
    clustering_params: Optional[Dict[str, Any]] = None,
    reduction_method: Optional[str] = None,
    n_components: Optional[int] = None,
    conference: Optional[str] = None,
    year: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
    """
    Get cached clustering results matching the parameters.

    When ``reduction_method`` and ``n_components`` are provided, only
    entries that match exactly (including the reduction method) are
    returned. When they are omitted (``None``), the reduction method
    is ignored and the most recent entry matching the clustering
    parameters is returned.

    Parameters
    ----------
    embedding_model : str
        Name of the embedding model.
    clustering_method : str
        Clustering algorithm used.
    n_clusters : int, optional
        Number of clusters to match. When ``None`` the query does
        **not** filter by this column.
    clustering_params : dict, optional
        Additional clustering parameters (e.g., distance_threshold, eps).
        ``None`` and an empty dict are equivalent: both match entries
        stored without parameters.
    reduction_method : str, optional
        Dimensionality reduction method. When provided, the query
        requires an exact match on this column.
    n_components : int, optional
        Number of components after reduction. When provided, the query
        requires an exact match on this column.
    conference : str, optional
        Conference name to filter by. ``None`` matches entries that
        have no conference set.
    year : int, optional
        Conference year to filter by. ``None`` matches entries that
        have no year set.

    Returns
    -------
    dict or None
        Cached clustering results of the newest matching entry, or None
        if not found.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        import json

        stmt = select(ClusteringCache).where(
            and_(
                ClusteringCache.embedding_model == embedding_model,
                ClusteringCache.clustering_method == clustering_method,
            )
        )
        # Optional exact-match filters
        if reduction_method is not None:
            stmt = stmt.where(ClusteringCache.reduction_method == reduction_method)
        if n_components is not None:
            stmt = stmt.where(ClusteringCache.n_components == n_components)
        if n_clusters is not None:
            stmt = stmt.where(ClusteringCache.n_clusters == n_clusters)
        # Conference/year scoping: None means "unscoped entries only".
        if conference is not None:
            stmt = stmt.where(ClusteringCache.conference == conference)
        else:
            stmt = stmt.where(ClusteringCache.conference.is_(None))
        if year is not None:
            stmt = stmt.where(ClusteringCache.year == year)
        else:
            stmt = stmt.where(ClusteringCache.year.is_(None))
        # Newest first; params are compared in Python below.
        stmt = stmt.order_by(ClusteringCache.created_at.desc())
        results = self._session.execute(stmt).scalars().all()

        # save_clustering_cache stores falsy params (None or {}) as NULL, so
        # an empty dict must be looked up exactly like None. Otherwise it can
        # never hit. Entries with stored params are different runs.
        if not clustering_params:
            for result in results:
                if result.clustering_params is None:
                    return json.loads(result.results_json)
            return None

        # Compare params order-independently by normalizing key order.
        wanted = json.dumps(clustering_params, sort_keys=True)
        for result in results:
            if result.clustering_params is None:
                continue
            stored = json.dumps(json.loads(result.clustering_params), sort_keys=True)
            if stored == wanted:
                return json.loads(result.results_json)
        return None
    except Exception as e:
        raise DatabaseError(f"Failed to get clustering cache: {str(e)}") from e
[docs]
def save_clustering_cache(
    self,
    embedding_model: str,
    reduction_method: str,
    n_components: int,
    clustering_method: str,
    results: Dict[str, Any],
    n_clusters: Optional[int] = None,
    clustering_params: Optional[Dict[str, Any]] = None,
    conference: Optional[str] = None,
    year: Optional[int] = None,
) -> None:
    """
    Save clustering results to cache.

    A new cache row is inserted containing the full results (including
    visualization coordinates) serialized as JSON. The
    ``reduction_method`` and ``n_components`` are stored as well, so that
    an exact-match lookup can return cached points directly.

    Parameters
    ----------
    embedding_model : str
        Name of the embedding model.
    reduction_method : str
        Dimensionality reduction method.
    n_components : int
        Number of components after reduction.
    clustering_method : str
        Clustering algorithm used.
    results : dict
        Clustering results to cache (full results including points).
    n_clusters : int, optional
        Number of clusters. When ``None``, the actual count is taken from
        ``results["statistics"]["n_clusters"]`` if present.
    clustering_params : dict, optional
        Additional clustering parameters. Falsy values (``None`` or ``{}``)
        are stored as NULL.
    conference : str, optional
        Conference name this entry is scoped to.
    year : int, optional
        Conference year this entry is scoped to.

    Raises
    ------
    DatabaseError
        If not connected or the save fails (the session is rolled back).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        import json

        # Fill n_clusters from the results statistics when not given. Guard
        # against "statistics" being present but None or not a mapping.
        if n_clusters is None:
            stats = results.get("statistics")
            if isinstance(stats, dict):
                actual = stats.get("n_clusters")
                if actual is not None:
                    n_clusters = int(actual)
        results_json = json.dumps(results)
        # Store params with sorted keys so the stored form is canonical.
        params_json = json.dumps(clustering_params, sort_keys=True) if clustering_params else None
        cache_entry = ClusteringCache(
            embedding_model=embedding_model,
            conference=conference,
            year=year,
            reduction_method=reduction_method,
            n_components=n_components,
            clustering_method=clustering_method,
            n_clusters=n_clusters,
            clustering_params=params_json,
            results_json=results_json,
        )
        self._session.add(cache_entry)
        self._session.commit()
        logger.info(
            f"Saved clustering cache: {clustering_method} with {n_clusters} clusters, "
            f"model={embedding_model}, reduction={reduction_method}, "
            f"conference={conference}, year={year}"
        )
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to save clustering cache: {str(e)}") from e
[docs]
def delete_papers_by_conference_year(self, conference: str, year: int) -> int:
    """
    Remove every paper belonging to the given conference and year.

    The deletion is executed as a single bulk DELETE and committed
    immediately.

    Parameters
    ----------
    conference : str
        Conference name (exact, as stored in the database).
    year : int
        Conference year.

    Returns
    -------
    int
        Number of deleted rows (0 if the driver reports no row count).

    Raises
    ------
    DatabaseError
        If not connected or the deletion fails (the session is rolled back).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        in_scope = and_(Paper.conference == conference, Paper.year == year)
        outcome = self._session.execute(delete(Paper).where(in_scope))
        self._session.commit()
        deleted = outcome.rowcount
        if deleted is None:
            deleted = 0
        logger.info(f"Deleted {deleted} papers for {conference}/{year}")
        return deleted
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to delete papers for {conference}/{year}: {str(e)}") from e
[docs]
def delete_clustering_cache_by_conference_year(self, conference: str, year: int) -> int:
    """
    Delete all clustering cache entries for a specific conference and year.

    Uses a single bulk DELETE, like ``delete_papers_by_conference_year``,
    instead of loading every row (including its potentially large
    ``results_json``) into memory just to delete it.

    Parameters
    ----------
    conference : str
        Conference name (exact, as stored in the database).
    year : int
        Conference year.

    Returns
    -------
    int
        Number of cache entries deleted (0 if the driver reports no row count).

    Raises
    ------
    DatabaseError
        If not connected or the deletion fails (the session is rolled back).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        result = self._session.execute(
            delete(ClusteringCache).where(
                and_(ClusteringCache.conference == conference, ClusteringCache.year == year)
            )
        )
        self._session.commit()
        count = result.rowcount if result.rowcount is not None else 0
        logger.info(f"Deleted {count} clustering cache entries for {conference}/{year}")
        return count
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to delete clustering cache for {conference}/{year}: {str(e)}") from e
[docs]
def count_clustering_cache_by_conference_year(self, conference: str, year: int) -> int:
    """
    Return how many clustering cache rows are scoped to *conference* and *year*.

    Parameters
    ----------
    conference : str
        Conference name (exact, as stored in the database).
    year : int
        Conference year.

    Returns
    -------
    int
        Number of matching cache entries (0 when none).

    Raises
    ------
    DatabaseError
        If not connected or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        in_scope = and_(ClusteringCache.conference == conference, ClusteringCache.year == year)
        count_query = select(func.count()).select_from(ClusteringCache).where(in_scope)
        total = self._session.execute(count_query).scalar()
        return total if total else 0
    except Exception as e:
        raise DatabaseError(f"Failed to count clustering cache for {conference}/{year}: {str(e)}") from e
[docs]
def clear_clustering_cache(self, embedding_model: Optional[str] = None) -> int:
    """
    Clear clustering cache, optionally filtered by embedding model.

    This is useful when embeddings change or the cache becomes stale.
    The deletion is a single bulk DELETE, so cached results are not
    loaded into memory first.

    Parameters
    ----------
    embedding_model : str, optional
        If provided, only clear cache entries for this embedding model.
        Only ``None`` clears all entries. An empty string is treated as a
        model name (matching nothing) rather than as "clear everything".

    Returns
    -------
    int
        Number of cache entries deleted (0 if the driver reports no row count).

    Raises
    ------
    DatabaseError
        If not connected or the deletion fails (the session is rolled back).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        stmt = delete(ClusteringCache)
        # Explicit None check: a falsy-but-set value such as "" must not
        # turn a scoped clear into wiping the entire cache.
        if embedding_model is not None:
            stmt = stmt.where(ClusteringCache.embedding_model == embedding_model)
        result = self._session.execute(stmt)
        self._session.commit()
        count = result.rowcount if result.rowcount is not None else 0
        if embedding_model is not None:
            logger.info(f"Cleared {count} clustering cache entries for model: {embedding_model}")
        else:
            logger.info(f"Cleared all {count} clustering cache entries")
        return count
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to clear clustering cache: {str(e)}") from e
[docs]
def update_clustering_cache_embedding_model(self, old_model: str, new_model: str) -> int:
    """
    Rename the embedding model recorded on clustering cache rows.

    Every ``ClusteringCache`` row whose ``embedding_model`` equals
    *old_model* is rewritten to *new_model*, and the change is committed.
    This keeps cached clustering results usable after the embedding model
    metadata has been renamed (e.g. via
    ``update_embedding_model_metadata.py``).

    Parameters
    ----------
    old_model : str
        Embedding model name currently stored in the cache.
    new_model : str
        Replacement embedding model name.

    Returns
    -------
    int
        Number of cache rows that were updated.

    Raises
    ------
    DatabaseError
        If not connected or the update fails (the session is rolled back).

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     count = db.update_clustering_cache_embedding_model(
    ...         "old-model", "new-model"
    ...     )
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        rows_to_rename = (
            self._session.execute(select(ClusteringCache).where(ClusteringCache.embedding_model == old_model))
            .scalars()
            .all()
        )
        renamed = len(rows_to_rename)
        for cache_row in rows_to_rename:
            cache_row.embedding_model = new_model
        self._session.commit()
        logger.info(f"Updated {renamed} clustering cache entries from model {old_model!r} to {new_model!r}")
        return renamed
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to update clustering cache embedding model: {str(e)}") from e
# ------------------------------------------------------------------
# Hierarchical label cache
# ------------------------------------------------------------------
[docs]
def get_hierarchical_label_cache(
    self,
    embedding_model: str,
    linkage: str = "ward",
) -> Optional[Dict[int, str]]:
    """
    Fetch the newest cached hierarchical labels for agglomerative clustering.

    The labels do not depend on the number of clusters or the distance
    threshold, so one entry serves every agglomerative run that shares
    the embedding model and linkage.

    Parameters
    ----------
    embedding_model : str
        Name of the embedding model.
    linkage : str, optional
        Agglomerative linkage method (default: ``"ward"``).

    Returns
    -------
    dict or None
        ``{node_id: label}`` with integer node ids, or ``None`` when no
        cache entry matches.

    Raises
    ------
    DatabaseError
        If not connected or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        import json

        same_key = and_(
            HierarchicalLabelCache.embedding_model == embedding_model,
            HierarchicalLabelCache.linkage == linkage,
        )
        newest_entry = (
            select(HierarchicalLabelCache)
            .where(same_key)
            .order_by(HierarchicalLabelCache.created_at.desc())
            .limit(1)
        )
        cached = self._session.execute(newest_entry).scalars().first()
        if cached is not None:
            stored = json.loads(cached.labels_json)
            # JSON object keys come back as strings; restore integer node ids.
            return {int(node_id): label for node_id, label in stored.items()}
        return None
    except Exception as e:
        raise DatabaseError(f"Failed to get hierarchical label cache: {str(e)}") from e
[docs]
def save_hierarchical_label_cache(
    self,
    embedding_model: str,
    labels: Dict[int, str],
    linkage: str = "ward",
) -> None:
    """
    Persist hierarchical cluster labels as a new cache row.

    Parameters
    ----------
    embedding_model : str
        Name of the embedding model.
    labels : dict
        ``{node_id: label}`` mapping to store; node ids are written as
        JSON string keys.
    linkage : str, optional
        Agglomerative linkage method (default: ``"ward"``).

    Raises
    ------
    DatabaseError
        If not connected or the save fails (the session is rolled back).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        import json

        # JSON object keys must be strings, so node ids are stringified here.
        serializable = {str(node_id): label for node_id, label in labels.items()}
        self._session.add(
            HierarchicalLabelCache(
                embedding_model=embedding_model,
                linkage=linkage,
                labels_json=json.dumps(serializable),
            )
        )
        self._session.commit()
        logger.info(
            f"Saved hierarchical label cache: {len(labels)} labels, "
            f"model={embedding_model}, linkage={linkage}"
        )
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to save hierarchical label cache: {str(e)}") from e
# ------------------------------------------------------------------ #
# Evaluation Q/A pair and result methods #
# ------------------------------------------------------------------ #
[docs]
def add_eval_qa_pair(
    self,
    conversation_id: str,
    turn_number: int,
    query: str,
    expected_answer: str,
    tool_name: Optional[str] = None,
    source_info: Optional[str] = None,
) -> int:
    """
    Store one evaluation question/answer pair and commit it.

    Parameters
    ----------
    conversation_id : str
        Identifier grouping turns in a conversation.
    turn_number : int
        Position within the conversation (0 = first).
    query : str
        The user query text.
    expected_answer : str
        The expected/reference answer.
    tool_name : str, optional
        MCP tool expected to be invoked.
    source_info : str, optional
        JSON metadata about how the pair was generated.

    Returns
    -------
    int
        Primary key assigned to the new row.

    Raises
    ------
    DatabaseError
        If there is no active session or the insert fails (the session is
        rolled back in that case).
    """
    session = self._session
    if not session:
        raise DatabaseError("Not connected to database")
    try:
        new_pair = EvalQAPair(
            conversation_id=conversation_id,
            turn_number=turn_number,
            query=query,
            expected_answer=expected_answer,
            tool_name=tool_name,
            source_info=source_info,
        )
        session.add(new_pair)
        session.commit()
        # Reading the key after commit may reload the row, so keep it
        # inside the guarded region.
        new_id = new_pair.id
    except Exception as exc:
        session.rollback()
        raise DatabaseError(f"Failed to add eval QA pair: {str(exc)}") from exc
    return new_id
[docs]
def get_eval_qa_pairs(
    self,
    verified_only: bool = False,
    tool_name: Optional[str] = None,
    conversation_id: Optional[str] = None,
    limit: Optional[int] = None,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """
    Fetch evaluation Q/A pairs, optionally filtered and paginated.

    Rows are ordered by ``conversation_id`` and then ``turn_number``.
    Filters and pagination arguments are applied only when truthy. An
    empty ``tool_name``/``conversation_id`` or a ``limit`` of ``0`` therefore
    behaves like ``None`` (no filter / no limit).

    Parameters
    ----------
    verified_only : bool
        If ``True``, keep only pairs with ``verified == 1``.
    tool_name : str, optional
        Keep only pairs expecting this MCP tool.
    conversation_id : str, optional
        Keep only pairs from this conversation.
    limit : int, optional
        Maximum number of pairs to return.
    offset : int
        Number of rows to skip (for pagination).

    Returns
    -------
    list of dict
        One dict per pair with keys ``id``, ``conversation_id``,
        ``turn_number``, ``query``, ``expected_answer``, ``tool_name``,
        ``verified``, ``source_info`` and ``created_at`` (ISO-8601 string,
        or ``None`` when unset).

    Raises
    ------
    DatabaseError
        If there is no active session or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    def _as_dict(pair) -> Dict[str, Any]:
        return {
            "id": pair.id,
            "conversation_id": pair.conversation_id,
            "turn_number": pair.turn_number,
            "query": pair.query,
            "expected_answer": pair.expected_answer,
            "tool_name": pair.tool_name,
            "verified": pair.verified,
            "source_info": pair.source_info,
            "created_at": pair.created_at.isoformat() if pair.created_at else None,
        }

    try:
        criteria = []
        if verified_only:
            criteria.append(EvalQAPair.verified == 1)
        if tool_name:
            criteria.append(EvalQAPair.tool_name == tool_name)
        if conversation_id:
            criteria.append(EvalQAPair.conversation_id == conversation_id)

        query_stmt = select(EvalQAPair)
        if criteria:
            # Multiple criteria passed to where() are AND-ed together.
            query_stmt = query_stmt.where(*criteria)
        query_stmt = query_stmt.order_by(EvalQAPair.conversation_id, EvalQAPair.turn_number)
        if offset:
            query_stmt = query_stmt.offset(offset)
        if limit:
            query_stmt = query_stmt.limit(limit)

        fetched = self._session.execute(query_stmt).scalars().all()
        return [_as_dict(pair) for pair in fetched]
    except Exception as exc:
        raise DatabaseError(f"Failed to get eval QA pairs: {str(exc)}") from exc
[docs]
def get_eval_qa_pair_count(self, verified_only: bool = False) -> int:
    """
    Return how many evaluation Q/A pairs are stored.

    Parameters
    ----------
    verified_only : bool
        If ``True``, count only pairs with ``verified == 1``.

    Returns
    -------
    int
        Number of matching pairs (``0`` if the count comes back empty).

    Raises
    ------
    DatabaseError
        If there is no active session or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        count_stmt = select(func.count()).select_from(EvalQAPair)
        if verified_only:
            count_stmt = count_stmt.where(EvalQAPair.verified == 1)
        total = self._session.execute(count_stmt).scalar()
        return total if total else 0
    except Exception as exc:
        raise DatabaseError(f"Failed to count eval QA pairs: {str(exc)}") from exc
[docs]
def update_eval_qa_pair(self, pair_id: int, **fields) -> bool:
    """
    Overwrite selected columns of an existing evaluation Q/A pair.

    Only ``query``, ``expected_answer``, ``tool_name``, ``verified`` and
    ``source_info`` may be changed. Any other keyword is silently ignored.
    If nothing editable remains, the database is not touched.

    Parameters
    ----------
    pair_id : int
        Primary key of the pair to update.
    **fields
        Column name -> new value.

    Returns
    -------
    bool
        ``True`` if the pair was found and updated. ``False`` if no
        editable field was given or no pair has ``pair_id``.

    Raises
    ------
    DatabaseError
        If there is no active session or the update fails (the session is
        rolled back in that case).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    editable = frozenset(("query", "expected_answer", "tool_name", "verified", "source_info"))
    changes = {name: value for name, value in fields.items() if name in editable}
    if not changes:
        return False

    try:
        target = self._session.get(EvalQAPair, pair_id)
        if target is None:
            return False
        for name, value in changes.items():
            setattr(target, name, value)
        self._session.commit()
        return True
    except Exception as exc:
        self._session.rollback()
        raise DatabaseError(f"Failed to update eval QA pair: {str(exc)}") from exc
[docs]
def delete_eval_qa_pair(self, pair_id: int) -> bool:
    """
    Remove one evaluation Q/A pair, looked up by primary key.

    Parameters
    ----------
    pair_id : int
        Primary key of the pair to remove.

    Returns
    -------
    bool
        ``True`` if a pair was found and deleted. ``False`` if none has
        ``pair_id`` (nothing is committed in that case).

    Raises
    ------
    DatabaseError
        If there is no active session or the deletion fails (the session
        is rolled back in that case).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        doomed = self._session.get(EvalQAPair, pair_id)
        found = doomed is not None
        if found:
            self._session.delete(doomed)
            self._session.commit()
        return found
    except Exception as exc:
        self._session.rollback()
        raise DatabaseError(f"Failed to delete eval QA pair: {str(exc)}") from exc
[docs]
def delete_verified_eval_qa_pairs(self) -> int:
    """
    Bulk-delete every evaluation Q/A pair with ``verified == 1``.

    Returns
    -------
    int
        Row count reported by the DELETE statement.

    Raises
    ------
    DatabaseError
        If there is no active session or the deletion fails (the session
        is rolled back in that case).
    """
    session = self._session
    if not session:
        raise DatabaseError("Not connected to database")
    try:
        outcome = session.execute(delete(EvalQAPair).where(EvalQAPair.verified == 1))
        removed = outcome.rowcount
        session.commit()
    except Exception as exc:
        session.rollback()
        raise DatabaseError(f"Failed to delete verified eval QA pairs: {str(exc)}") from exc
    return removed
[docs]
def delete_eval_results(self, run_id: Optional[str] = None) -> int:
    """
    Bulk-delete stored evaluation results.

    Parameters
    ----------
    run_id : str, optional
        When given (including an empty string), only rows of that run are
        deleted. When ``None``, **every** stored result is deleted.

    Returns
    -------
    int
        Row count reported by the DELETE statement.

    Raises
    ------
    DatabaseError
        If there is no active session or the deletion fails (the session
        is rolled back in that case).
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        if run_id is None:
            purge = delete(EvalResult)
        else:
            purge = delete(EvalResult).where(EvalResult.run_id == run_id)
        removed = self._session.execute(purge).rowcount
        self._session.commit()
        return removed
    except Exception as exc:
        self._session.rollback()
        raise DatabaseError(f"Failed to delete eval results: {str(exc)}") from exc
[docs]
def add_eval_result(
    self,
    run_id: str,
    qa_pair_id: int,
    actual_answer: Optional[str] = None,
    actual_tool_name: Optional[str] = None,
    answer_score: Optional[float] = None,
    tool_correct: Optional[int] = None,
    latency_ms: Optional[int] = None,
    error: Optional[str] = None,
    judge_reasoning: Optional[str] = None,
) -> int:
    """
    Store the outcome of evaluating one Q/A pair and commit it.

    Parameters
    ----------
    run_id : str
        Identifier for the evaluation run.
    qa_pair_id : int
        ID of the evaluated Q/A pair.
    actual_answer : str, optional
        Answer produced by the RAG system.
    actual_tool_name : str, optional
        MCP tool actually invoked.
    answer_score : float, optional
        LLM-judged quality score (1–5).
    tool_correct : int, optional
        1 if the correct tool was used, 0 otherwise.
    latency_ms : int, optional
        Query latency in milliseconds.
    error : str, optional
        Error message if the query failed.
    judge_reasoning : str, optional
        LLM judge's reasoning for the score.

    Returns
    -------
    int
        Primary key assigned to the new row.

    Raises
    ------
    DatabaseError
        If there is no active session or the insert fails (the session is
        rolled back in that case).
    """
    session = self._session
    if not session:
        raise DatabaseError("Not connected to database")
    try:
        outcome = EvalResult(
            run_id=run_id,
            qa_pair_id=qa_pair_id,
            actual_answer=actual_answer,
            actual_tool_name=actual_tool_name,
            answer_score=answer_score,
            tool_correct=tool_correct,
            latency_ms=latency_ms,
            error=error,
            judge_reasoning=judge_reasoning,
        )
        session.add(outcome)
        session.commit()
        # The key may be reloaded after commit, so read it inside the guard.
        new_id = outcome.id
    except Exception as exc:
        session.rollback()
        raise DatabaseError(f"Failed to add eval result: {str(exc)}") from exc
    return new_id
[docs]
def get_eval_results(
    self,
    run_id: Optional[str] = None,
    limit: Optional[int] = None,
    offset: int = 0,
) -> List[Dict[str, Any]]:
    """
    Fetch stored evaluation results in primary-key order.

    ``run_id``, ``limit`` and ``offset`` are applied only when truthy. An
    empty ``run_id`` therefore returns results of all runs, and
    ``limit=0`` means "no limit".

    Parameters
    ----------
    run_id : str, optional
        Restrict to this evaluation run.
    limit : int, optional
        Maximum number of results to return.
    offset : int
        Number of rows to skip.

    Returns
    -------
    list of dict
        One dict per result with keys ``id``, ``run_id``, ``qa_pair_id``,
        ``actual_answer``, ``actual_tool_name``, ``answer_score``,
        ``tool_correct``, ``latency_ms``, ``error``, ``judge_reasoning`` and
        ``created_at`` (ISO-8601 string, or ``None`` when unset).

    Raises
    ------
    DatabaseError
        If there is no active session or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    def _as_dict(res) -> Dict[str, Any]:
        return {
            "id": res.id,
            "run_id": res.run_id,
            "qa_pair_id": res.qa_pair_id,
            "actual_answer": res.actual_answer,
            "actual_tool_name": res.actual_tool_name,
            "answer_score": res.answer_score,
            "tool_correct": res.tool_correct,
            "latency_ms": res.latency_ms,
            "error": res.error,
            "judge_reasoning": res.judge_reasoning,
            "created_at": res.created_at.isoformat() if res.created_at else None,
        }

    try:
        query_stmt = select(EvalResult)
        if run_id:
            query_stmt = query_stmt.where(EvalResult.run_id == run_id)
        query_stmt = query_stmt.order_by(EvalResult.id)
        if offset:
            query_stmt = query_stmt.offset(offset)
        if limit:
            query_stmt = query_stmt.limit(limit)
        fetched = self._session.execute(query_stmt).scalars().all()
        return [_as_dict(res) for res in fetched]
    except Exception as exc:
        raise DatabaseError(f"Failed to get eval results: {str(exc)}") from exc
[docs]
def get_eval_run_ids(self) -> List[str]:
    """
    List every distinct evaluation run id, oldest run first.

    Runs are ordered by the earliest ``created_at`` among their results,
    so the most recently started run comes last.

    Returns
    -------
    list of str
        Distinct run ids in chronological order.

    Raises
    ------
    DatabaseError
        If there is no active session or the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        first_seen = func.min(EvalResult.created_at)
        runs_stmt = select(EvalResult.run_id).group_by(EvalResult.run_id).order_by(first_seen.asc())
        # Single-column select: scalars() yields the run_id values directly.
        return list(self._session.execute(runs_stmt).scalars())
    except Exception as exc:
        raise DatabaseError(f"Failed to get eval run IDs: {str(exc)}") from exc
[docs]
def get_eval_run_summary(self, run_id: str) -> Dict[str, Any]:
    """
    Compute summary statistics for an evaluation run.

    Results are loaded through :meth:`get_eval_results` and aggregated in
    Python. Each average ignores rows whose value is ``None``. ``run_date``
    comes from a separate ``MIN(created_at)`` query.

    Parameters
    ----------
    run_id : str
        The evaluation run identifier.

    Returns
    -------
    dict
        Dictionary with keys:

        - total : int – number of evaluated pairs
        - avg_score : float or None – mean answer quality score
        - tool_accuracy : float or None – fraction of correct tool selections
        - avg_latency_ms : float or None – mean latency
        - error_count : int – number of results with a non-empty ``error``
        - run_date : datetime or None – timestamp of the first result in the run

        When the run has no results, ``total``/``error_count`` are ``0`` and
        every other value is ``None``.

    Raises
    ------
    DatabaseError
        If there is no active session or a query fails. A ``DatabaseError``
        from :meth:`get_eval_results` is propagated unchanged.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    try:
        results = self.get_eval_results(run_id=run_id)
        if not results:
            return {
                "total": 0,
                "avg_score": None,
                "tool_accuracy": None,
                "avg_latency_ms": None,
                "error_count": 0,
                "run_date": None,
            }
        total = len(results)
        scores = [r["answer_score"] for r in results if r["answer_score"] is not None]
        tool_vals = [r["tool_correct"] for r in results if r["tool_correct"] is not None]
        latencies = [r["latency_ms"] for r in results if r["latency_ms"] is not None]
        errors = sum(1 for r in results if r["error"])
        # Determine the timestamp of the earliest result in this run
        stmt = select(func.min(EvalResult.created_at)).where(EvalResult.run_id == run_id)
        run_date = self._session.execute(stmt).scalar()
        return {
            "total": total,
            "avg_score": (sum(scores) / len(scores)) if scores else None,
            "tool_accuracy": (sum(tool_vals) / len(tool_vals)) if tool_vals else None,
            "avg_latency_ms": (sum(latencies) / len(latencies)) if latencies else None,
            "error_count": errors,
            "run_date": run_date,
        }
    except DatabaseError:
        # Already descriptive (e.g. from get_eval_results). Re-wrapping would
        # only produce a doubled "Failed to ...: Failed to ..." message.
        raise
    except Exception as e:
        raise DatabaseError(f"Failed to compute eval run summary: {str(e)}") from e
# ------------------------------------------------------------------
# Registry export / import helpers
# ------------------------------------------------------------------
[docs]
def export_papers_to_sqlite(
    self,
    output_path: "Path",
    conference: str,
    year: int,
) -> int:
    """
    Export papers for a given conference and year to a standalone SQLite file.

    Only papers are filtered by ``conference`` and ``year``. *All* local
    hierarchical label cache rows and *all* embeddings metadata rows are
    copied as well. Clustering cache is **not** included — it is exported
    separately via :meth:`export_clustering_cache_to_json`.

    Missing tables are created in the target file. Rows already present in
    an existing target file are left in place (nothing is deleted), so
    primary-key collisions with them make the export fail.

    Parameters
    ----------
    output_path : Path or str
        Destination path for the SQLite file. Missing parent directories
        are created.
    conference : str
        Conference name to export.
    year : int
        Year to export.

    Returns
    -------
    int
        Number of papers exported. This is ``0`` when no paper matches,
        which is not treated as an error.

    Raises
    ------
    DatabaseError
        If there is no active session or the export fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    export_engine = None
    try:
        # Accept plain strings as well as Path objects.
        target = Path(output_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        export_engine = create_engine(f"sqlite:///{target}")
        Base.metadata.create_all(export_engine)
        paper_count = 0
        with Session(export_engine) as export_session:
            # Export papers filtered by conference and year
            query = select(Paper).where(and_(Paper.conference == conference, Paper.year == year))
            for paper in self._session.execute(query).scalars():
                paper_dict = {c.name: getattr(paper, c.name) for c in Paper.__table__.columns}
                export_session.add(Paper(**paper_dict))
                paper_count += 1
            # Export hierarchical label cache (all rows, unfiltered)
            for entry in self._session.execute(select(HierarchicalLabelCache)).scalars():
                entry_dict = {c.name: getattr(entry, c.name) for c in HierarchicalLabelCache.__table__.columns}
                export_session.add(HierarchicalLabelCache(**entry_dict))
            # Export embeddings metadata (all rows, unfiltered)
            for entry in self._session.execute(select(EmbeddingsMetadata)).scalars():
                entry_dict = {c.name: getattr(entry, c.name) for c in EmbeddingsMetadata.__table__.columns}
                export_session.add(EmbeddingsMetadata(**entry_dict))
            export_session.commit()
        return paper_count
    except DatabaseError:
        raise
    except Exception as e:
        raise DatabaseError(f"Failed to export papers: {str(e)}") from e
    finally:
        # Release the SQLite connection pool (and the file handle) even
        # when the export fails part-way.
        if export_engine is not None:
            export_engine.dispose()
[docs]
def import_papers_from_sqlite(
    self,
    sqlite_path: "Path",
    conference: str,
    year: int,
) -> int:
    """
    Import papers for a given conference and year from a SQLite file.

    The import runs as a single transaction on the local session. If any
    step fails, everything is rolled back, including the deletions, so a
    failed import never leaves the old papers removed and the new ones
    missing.

    Steps:

    1. If both the source file and the local database have an
       :class:`EmbeddingsMetadata` row, their embedding model names
       (after :func:`normalize_model_name`) must match.
    2. Local papers whose ``conference``/``year`` equal the arguments are
       deleted.
    3. Local hierarchical label cache entries whose ``embedding_model``
       occurs in the source file's label cache are deleted.
    4. Every paper row of the source file is merged into the local database.
       ``merge()`` is used so a primary-key collision updates the existing
       row instead of failing.
    5. Every hierarchical label cache row of the source file is added with
       a freshly generated primary key.
    6. The source embeddings metadata row is copied only if the local
       database has none.

    Parameters
    ----------
    sqlite_path : Path or str
        Path to the source SQLite file. It must exist.
    conference : str
        Conference name being imported (scopes the deletion of papers).
    year : int
        Year being imported (scopes the deletion of papers).

    Returns
    -------
    int
        Number of paper rows read from the source file and merged.

    Raises
    ------
    EmbeddingModelConflictError
        If the embedding models differ (subclass of ``DatabaseError``).
    DatabaseError
        If there is no active session, the source file does not exist, or
        the import fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")
    source_engine = None
    try:
        source_file = Path(sqlite_path)
        # SQLite creates a missing database file on first connect. Check up
        # front so a wrong path fails cleanly instead of leaving an empty
        # file behind and failing later with "no such table".
        if not source_file.is_file():
            raise DatabaseError(f"SQLite file not found: {source_file}")
        source_engine = create_engine(
            f"sqlite:///{source_file}",
            connect_args={"check_same_thread": False},
        )
        paper_count = 0
        with Session(source_engine) as source_session:
            # --- Validate EmbeddingsMetadata consistency (read-only) ---
            imported_meta = source_session.execute(select(EmbeddingsMetadata)).scalars().first()
            if imported_meta:
                existing_meta = self._session.execute(select(EmbeddingsMetadata)).scalars().first()
                if existing_meta and normalize_model_name(existing_meta.embedding_model) != normalize_model_name(
                    imported_meta.embedding_model
                ):
                    raise EmbeddingModelConflictError(
                        existing_meta.embedding_model, imported_meta.embedding_model
                    )
            # Delete existing papers for this conference+year
            self._session.execute(delete(Paper).where(and_(Paper.conference == conference, Paper.year == year)))
            # Delete only hierarchical label cache entries whose
            # embedding_model matches one from the imported data
            imported_models = (
                source_session.execute(select(HierarchicalLabelCache.embedding_model).distinct()).scalars().all()
            )
            if imported_models:
                self._session.execute(
                    delete(HierarchicalLabelCache).where(
                        HierarchicalLabelCache.embedding_model.in_(imported_models)
                    )
                )
            # Flush, do NOT commit: the deletions must stay in the same
            # transaction as the inserts below so a later failure rolls
            # everything back instead of losing the existing papers.
            self._session.flush()
            # Import papers — use merge() to handle any UID collisions
            # (e.g. the same paper existing under a different conference
            # casing that the DELETE above didn't catch).
            for paper in source_session.execute(select(Paper)).scalars():
                paper_dict = {c.name: getattr(paper, c.name) for c in Paper.__table__.columns}
                self._session.merge(Paper(**paper_dict))
                paper_count += 1
            # Import hierarchical labels
            for entry in source_session.execute(select(HierarchicalLabelCache)).scalars():
                entry_dict = {c.name: getattr(entry, c.name) for c in HierarchicalLabelCache.__table__.columns}
                # Let the local database assign the key. Local rows for other
                # embedding models are kept, so the source file's ids may
                # already be taken. Explicit ids would also leave a
                # PostgreSQL sequence behind the table contents.
                entry_dict.pop("id", None)
                self._session.add(HierarchicalLabelCache(**entry_dict))
            # Import embeddings metadata (only if not already present)
            if imported_meta:
                existing_meta = self._session.execute(select(EmbeddingsMetadata)).scalars().first()
                if not existing_meta:
                    meta_dict = {
                        c.name: getattr(imported_meta, c.name) for c in EmbeddingsMetadata.__table__.columns
                    }
                    self._session.add(EmbeddingsMetadata(**meta_dict))
            self._session.commit()
        return paper_count
    except DatabaseError:
        # Nothing has been committed yet; discard any pending changes.
        self._session.rollback()
        raise
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to import papers: {str(e)}") from e
    finally:
        # Release the source SQLite file even when the import fails.
        if source_engine is not None:
            source_engine.dispose()
# ------------------------------------------------------------------
# Clustering cache JSON export / import
# ------------------------------------------------------------------
[docs]
def export_clustering_cache_to_json(
    self,
    conference: str,
    year: int,
) -> Dict[str, Any]:
    """
    Serialise the clustering cache rows of one conference/year to a dict.

    Every column except ``id`` is exported, because the key is regenerated
    on import. String values of ``clustering_params`` and ``results_json``
    are decoded from JSON. A non-null ``created_at`` is rendered with
    ``isoformat()``.

    Parameters
    ----------
    conference : str
        Conference name to match.
    year : int
        Year to match.

    Returns
    -------
    dict
        ``{"entries": [...]}`` with one dict per matching cache row.

    Raises
    ------
    DatabaseError
        If there is no active session or the export fails (including
        undecodable JSON in a cached column).
    """
    import json as _json

    if not self._session:
        raise DatabaseError("Not connected to database")

    def _encode(column_name: str, value: Any) -> Any:
        # Turn a stored column value into its JSON-friendly form.
        if column_name in ("clustering_params", "results_json") and isinstance(value, str):
            return _json.loads(value)
        if column_name == "created_at" and value is not None:
            return value.isoformat()
        return value

    try:
        matching = select(ClusteringCache).where(
            and_(
                ClusteringCache.conference == conference,
                ClusteringCache.year == year,
            )
        )
        exported: List[Dict[str, Any]] = []
        for cached in self._session.execute(matching).scalars():
            exported.append(
                {
                    column.name: _encode(column.name, getattr(cached, column.name))
                    for column in ClusteringCache.__table__.columns
                    if column.name != "id"  # regenerated on import
                }
            )
        return {"entries": exported}
    except DatabaseError:
        raise
    except Exception as exc:
        raise DatabaseError(f"Failed to export clustering cache: {str(exc)}") from exc
[docs]
def import_clustering_cache_from_json(
    self,
    data: Dict[str, Any],
    conference: str,
    year: int,
    overwrite_embedding_model: Optional[str] = None,
) -> int:
    """
    Import clustering cache entries from a JSON dictionary.

    Existing clustering cache entries matching *conference* and *year* are
    deleted before the new entries are inserted. The deletion and the
    inserts are committed together, and any failure rolls both back.

    Each entry is normalised before insertion:

    - ``clustering_params`` / ``results_json`` values that are not strings
      are re-serialised with :func:`json.dumps`;
    - a string ``created_at`` is parsed with :meth:`datetime.fromisoformat`;
    - any ``id`` is dropped so the database assigns a fresh key;
    - ``conference`` and ``year`` are forced to the import scope;
    - a missing ``n_clusters`` is filled from
      ``results_json["statistics"]["n_clusters"]`` when that value exists
      and converts to ``int``. Otherwise it is left unset.

    Parameters
    ----------
    data : dict
        Dictionary previously returned by
        :meth:`export_clustering_cache_to_json`. A missing or ``null``
        ``entries`` value imports nothing (but still clears the scope).
    conference : str
        Conference name for scoping the delete.
    year : int
        Year for scoping the delete.
    overwrite_embedding_model : str, optional
        When provided, the ``embedding_model`` field of every imported
        entry is replaced with this value. Use this when importing an
        artifact whose embedding model differs from the locally
        configured one (i.e. when ``--ignore-embedding-model-mismatch``
        was passed) so that :meth:`get_clustering_cache` can find the
        entries using the local model name.

    Returns
    -------
    int
        Number of cache entries imported.

    Raises
    ------
    DatabaseError
        If there is no active session or the import fails.
    """
    import json as _json

    if not self._session:
        raise DatabaseError("Not connected to database")

    def _n_clusters_from_results(results: Any) -> Optional[int]:
        """Return ``statistics.n_clusters`` from a results payload, or ``None``."""
        if isinstance(results, str):
            try:
                results = _json.loads(results)
            except (ValueError, TypeError):
                return None
        if not isinstance(results, dict):
            return None
        # "statistics" may be absent, null or a non-object in older or
        # hand-edited artifacts; none of those should abort the import.
        stats = results.get("statistics")
        if not isinstance(stats, dict):
            return None
        actual = stats.get("n_clusters")
        if actual is None:
            return None
        try:
            return int(actual)
        except (TypeError, ValueError):
            return None

    try:
        # Delete existing matching entries using column-based filtering.
        self._session.execute(
            delete(ClusteringCache).where(
                and_(
                    ClusteringCache.conference == conference,
                    ClusteringCache.year == year,
                )
            )
        )
        self._session.flush()
        # For PostgreSQL the auto-increment sequence can fall out of sync when rows
        # were previously inserted with explicit primary-key values (e.g. by an older
        # version of import_papers_from_sqlite that included the clustering cache).
        # If other rows for a different conference/year remain in the table and the
        # sequence tries to reuse one of their IDs, the next INSERT will raise
        # UniqueViolation. Reset the sequence to max(existing id)+1 so that every
        # new row gets a safe ID regardless of sequence history.
        db_url = self.database_url.lower()
        if "postgres" in db_url:  # also matches "postgresql"
            self._session.execute(
                text(
                    "SELECT setval("
                    " pg_get_serial_sequence('clustering_cache', 'id'),"
                    " COALESCE((SELECT MAX(id) FROM clustering_cache), 0) + 1,"
                    " false"
                    ")"
                )
            )
            self._session.flush()
        count = 0
        for item in data.get("entries") or []:
            row = dict(item)  # shallow copy
            # Re-serialise parsed JSON fields back to strings for DB storage
            for json_field in ("clustering_params", "results_json"):
                value = row.get(json_field)
                if value is not None and not isinstance(value, str):
                    row[json_field] = _json.dumps(value)
            # Convert ISO-format created_at back to datetime
            if isinstance(row.get("created_at"), str):
                row["created_at"] = datetime.fromisoformat(row["created_at"])
            # Drop 'id' if present — let the DB auto-generate it
            row.pop("id", None)
            # Replace embedding_model when instructed (e.g. mismatch ignored)
            if overwrite_embedding_model is not None:
                row["embedding_model"] = overwrite_embedding_model
            # Ensure conference/year columns are set from the import scope
            row["conference"] = conference
            row["year"] = year
            # Fill n_clusters from results statistics if not set
            if row.get("n_clusters") is None:
                n_clusters = _n_clusters_from_results(row.get("results_json"))
                if n_clusters is not None:
                    row["n_clusters"] = n_clusters
            self._session.add(ClusteringCache(**row))
            count += 1
        self._session.commit()
        return count
    except DatabaseError:
        raise
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to import clustering cache: {str(e)}") from e