"""
Database Module
===============
This module provides functionality to load JSON data into a SQL database.
Supports both SQLite and PostgreSQL backends via SQLAlchemy.
"""
import hashlib
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from sqlalchemy import create_engine, select, func, or_, and_, text
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError, ProgrammingError, IntegrityError
# Import Pydantic models from plugin framework
from abstracts_explorer.plugin import LightweightPaper
# Import SQLAlchemy models
from abstracts_explorer.db_models import Base, Paper, EmbeddingsMetadata, ClusteringCache, ValidationData
logger = logging.getLogger(__name__)
[docs]
class DatabaseError(Exception):
    """
    Error type for failed database operations.

    Raised by ``DatabaseManager`` methods when they are called without an
    open connection or when the underlying database call fails.
    """
[docs]
class DatabaseManager:
"""
Manager for SQL database operations using SQLAlchemy.
Supports SQLite and PostgreSQL backends through SQLAlchemy connection URLs.
Database configuration is read from the config file (PAPER_DB variable).
Attributes
----------
database_url : str
SQLAlchemy database URL from configuration.
engine : Engine or None
SQLAlchemy engine instance.
SessionLocal : sessionmaker or None
SQLAlchemy session factory.
_session : Session or None
Active database session if connected.
Examples
--------
>>> # Database configuration comes from config file
>>> db = DatabaseManager()
>>> db.connect()
>>> db.create_tables()
>>> db.close()
"""
[docs]
def __init__(self):
    """
    Initialize the DatabaseManager.

    Reads the database URL from the application configuration. No
    connection is opened here; call ``connect()`` (or use the instance as
    a context manager) to open one.
    """
    # Imported inside the method, as in the original code
    # (presumably to avoid an import cycle at module load - TODO confirm).
    from abstracts_explorer.config import get_config

    self.database_url = get_config().database_url
    self.engine: Optional[Engine] = None
    self.SessionLocal: Optional[sessionmaker] = None
    self._session: Optional[Session] = None
    # Pooled raw connection opened by connect(). Declared here so the
    # attribute always exists, even if connect() was never called.
    self._raw_connection = None
    # Legacy attribute kept for backward compatibility: None until
    # connect() stores the raw driver-level connection in it.
    self.connection = None
[docs]
def connect(self) -> None:
    """
    Connect to the database.

    For ``sqlite:///`` URLs the parent directory of the database file is
    created first. An engine, a session factory, one session and one raw
    driver connection (exposed through the legacy ``connection``
    attribute) are then created.

    The manager's attributes are only assigned once every step has
    succeeded. If any step fails, whatever was created during this call
    is released again and the manager's attributes are left as they were.

    Raises
    ------
    DatabaseError
        If connection fails.
    """
    sqlite_file_prefix = "sqlite:///"
    engine = None
    session = None
    raw_connection = None
    try:
        if self.database_url.startswith(sqlite_file_prefix):
            # Strip only the leading prefix (str.replace would remove
            # every occurrence) and drop any "?query" suffix.
            db_path_str = self.database_url[len(sqlite_file_prefix):].partition("?")[0]
            if db_path_str and db_path_str != ":memory:":
                Path(db_path_str).parent.mkdir(parents=True, exist_ok=True)

        connect_args = {}
        if self.database_url.startswith("sqlite"):
            # Let the SQLite connection be used from threads other than
            # the one that created it.
            connect_args = {"check_same_thread": False}

        engine = create_engine(
            self.database_url,
            connect_args=connect_args,
            echo=False,  # Set to True for SQL debugging
        )
        session_factory = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=engine,
        )
        session = session_factory()

        # Raw connection for legacy callers that use .connection.cursor()
        raw_connection = engine.raw_connection()
        driver_connection = raw_connection.driver_connection
    except Exception as e:
        # Release, in reverse order of creation, whatever this call
        # managed to open so that a failed connect leaks nothing.
        for resource, release in (
            (raw_connection, "close"),
            (session, "close"),
            (engine, "dispose"),
        ):
            if resource is None:
                continue
            try:
                getattr(resource, release)()
            except Exception:
                logger.debug("Cleanup after failed connect raised", exc_info=True)
        raise DatabaseError(f"Failed to connect to database: {str(e)}") from e

    self.engine = engine
    self.SessionLocal = session_factory
    self._session = session
    self._raw_connection = raw_connection
    self.connection = driver_connection
    logger.debug(f"Connected to database: {self._mask_url(self.database_url)}")
def _mask_url(self, url: str) -> str:
"""Mask password in URL for logging."""
if "@" in url and ":" in url:
parts = url.split("@")
if len(parts) == 2:
before_at = parts[0]
if "://" in before_at:
protocol_user = before_at.rsplit(":", 1)[0]
return f"{protocol_user}:***@{parts[1]}"
return url
[docs]
def close(self) -> None:
    """
    Release the session, the raw connection and the engine.

    Each resource is only touched when it is currently set, so calling
    this on a manager that never connected is harmless. Afterwards the
    session factory and the legacy ``connection`` attribute are reset to
    ``None``.
    """
    if self._session:
        self._session.close()
        self._session = None

    # The attribute only exists once connect() has run.
    raw = getattr(self, "_raw_connection", None)
    if raw:
        raw.close()
        self._raw_connection = None

    if self.engine:
        self.engine.dispose()
        self.engine = None

    self.SessionLocal = None
    self.connection = None
    logger.debug("Database connection closed")
[docs]
def __enter__(self):
"""Context manager entry."""
self.connect()
return self
[docs]
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.close()
[docs]
def create_tables(self) -> None:
    """
    Create every table declared on ``Base.metadata``.

    The call is idempotent: ``checkfirst=True`` skips tables that already
    exist, and database errors whose message indicates that the object
    already exists are logged and ignored rather than raised.

    Raises
    ------
    DatabaseError
        If not connected, or if table creation fails for any other reason.
    """
    if not self.engine:
        raise DatabaseError("Not connected to database")

    # Message fragments that mean "the table is already there". They can
    # still appear despite checkfirst, e.g. when another process creates
    # the tables concurrently.
    benign_markers = ("already exists", "duplicate", "pg_type_typname_nsp_index")

    try:
        Base.metadata.create_all(bind=self.engine, checkfirst=True)
        logger.debug("Database tables created successfully")
    except (OperationalError, ProgrammingError, IntegrityError) as e:
        lowered = str(e).lower()
        for marker in benign_markers:
            if marker in lowered:
                logger.debug(f"Tables already exist (this is normal): {str(e)}")
                return
        # Any other database error is reported with context.
        raise DatabaseError(f"Failed to create tables: {str(e)}") from e
    except Exception as e:
        raise DatabaseError(f"Failed to create tables: {str(e)}") from e
[docs]
def add_paper(self, paper: LightweightPaper) -> Optional[str]:
    """
    Insert one paper, skipping it when an identical UID already exists.

    The UID is the first 16 hex characters of the SHA-256 hash of
    ``"{title}:{original_id}:{conference}:{year}"``. Authors are stored as
    one ``"; "``-joined string and keywords as one ``", "``-joined string.

    Parameters
    ----------
    paper : LightweightPaper
        Validated paper object to insert.

    Returns
    -------
    str or None
        UID of the inserted paper, or None if a paper with that UID was
        already present.

    Raises
    ------
    DatabaseError
        If not connected or if the insertion fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     db.create_tables()
    ...     paper = LightweightPaper(
    ...         title="Test Paper",
    ...         authors=["John Doe"],
    ...         abstract="Test abstract",
    ...         session="Session 1",
    ...         poster_position="P1",
    ...         year=2025,
    ...         conference="NeurIPS",
    ...     )
    ...     paper_uid = db.add_paper(paper)
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        source_id = paper.original_id or None
        title = paper.title
        abstract = paper.abstract

        raw_authors = paper.authors
        if isinstance(raw_authors, list):
            authors_str = "; ".join(str(author) for author in raw_authors)
        elif raw_authors:
            authors_str = str(raw_authors)
        else:
            authors_str = ""

        # The original id is part of the hash input, so two papers that
        # share title, conference and year still get different UIDs.
        digest = hashlib.sha256(
            f"{title}:{source_id}:{paper.conference}:{paper.year}".encode("utf-8")
        )
        uid = digest.hexdigest()[:16]

        duplicate = self._session.execute(
            select(Paper).where(Paper.uid == uid)
        ).scalar_one_or_none()
        if duplicate:
            logger.debug(f"Skipping duplicate paper: {title} (uid: {uid})")
            return None

        # Anything other than a list of keywords is stored as "".
        raw_keywords = paper.keywords
        keywords_str = (
            ", ".join(str(k) for k in raw_keywords)
            if isinstance(raw_keywords, list)
            else ""
        )

        stored_original_id = str(paper.original_id) if paper.original_id else None

        # Fields copied from the paper to the ORM row without conversion.
        passthrough = (
            "session",
            "poster_position",
            "paper_pdf_url",
            "poster_image_url",
            "url",
            "room_name",
            "starttime",
            "endtime",
            "award",
            "year",
            "conference",
        )
        row = Paper(
            uid=uid,
            original_id=stored_original_id,
            title=title,
            authors=authors_str,
            abstract=abstract,
            keywords=keywords_str,
            **{field: getattr(paper, field) for field in passthrough},
        )
        self._session.add(row)
        self._session.commit()
        return uid
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to add paper: {str(e)}") from e
[docs]
def add_papers(self, papers: List[LightweightPaper]) -> int:
    """
    Insert several papers by calling ``add_paper`` for each one in turn.

    Every paper is inserted by its own ``add_paper`` call, so if one call
    raises, the error propagates and the remaining papers are not
    processed.

    Parameters
    ----------
    papers : list of LightweightPaper
        Validated paper objects to insert.

    Returns
    -------
    int
        Number of papers actually inserted (duplicates are not counted).

    Raises
    ------
    DatabaseError
        If not connected or if inserting any paper fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     db.create_tables()
    ...     count = db.add_papers([
    ...         LightweightPaper(
    ...             title="Paper 1", authors=["Author 1"], abstract="Abstract 1",
    ...             session="Session 1", poster_position="P1",
    ...             year=2025, conference="NeurIPS",
    ...         ),
    ...     ])
    >>> print(f"Inserted {count} papers")
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    # add_paper returns None for duplicates; count only real insertions.
    inserted_count = sum(
        1 for paper in papers if self.add_paper(paper) is not None
    )
    logger.debug(f"Successfully inserted {inserted_count} of {len(papers)} papers")
    return inserted_count
[docs]
def donate_validation_data(self, paper_priorities: Dict[str, Dict[str, Any]]) -> int:
    """
    Store donated paper rating data for validation purposes.

    All entries are validated before anything is added to the session, and
    the session is rolled back on every failure, so a rejected donation
    never leaves partial data behind to be committed by a later call.

    Parameters
    ----------
    paper_priorities : Dict[str, Dict[str, Any]]
        Dictionary mapping paper UIDs to priority data. Each priority data
        dict may contain:

        - priority (int): Rating value (stored as 0 when missing)
        - searchTerm (str, optional): Search term associated with the rating

    Returns
    -------
    int
        Number of papers successfully donated.

    Raises
    ------
    ValueError
        If paper_priorities is empty or contains a non-dict entry.
    DatabaseError
        If not connected or if the database operation fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     priorities = {
    ...         "abc123": {"priority": 5, "searchTerm": "machine learning"},
    ...         "def456": {"priority": 4, "searchTerm": "deep learning"},
    ...     }
    ...     count = db.donate_validation_data(priorities)
    ...     print(f"Donated {count} papers")
    """
    if not paper_priorities:
        raise ValueError("No data provided")
    if self._session is None:
        raise DatabaseError("Not connected to database")

    try:
        # Validate the whole payload first so a bad entry cannot leave
        # earlier entries pending in the session.
        for priority_data in paper_priorities.values():
            if not isinstance(priority_data, dict):
                raise ValueError(
                    "Invalid data format. Expected dict with priority and searchTerm"
                )

        donated_count = 0
        for paper_uid, priority_data in paper_priorities.items():
            self._session.add(
                ValidationData(
                    paper_uid=paper_uid,
                    priority=priority_data.get("priority", 0),
                    search_term=priority_data.get("searchTerm", None),
                )
            )
            donated_count += 1

        self._session.commit()
        logger.info(f"Successfully donated {donated_count} papers to validation data")
        return donated_count
    except ValueError:
        # Validation errors keep their type, but nothing may stay pending.
        self._session.rollback()
        raise
    except Exception as e:
        self._session.rollback()
        logger.error(f"Error donating validation data: {e}", exc_info=True)
        raise DatabaseError(f"Failed to donate validation data: {e}") from e
[docs]
def query(self, sql: str, parameters: tuple = ()) -> List[Dict[str, Any]]:
    """
    Execute a raw SQL query and return the rows as dictionaries.

    This method exists for backward compatibility with raw SQL callers;
    new code should prefer the SQLAlchemy ORM methods.

    Positional ``?`` placeholders are rewritten, left to right, into the
    named parameters ``:param0``, ``:param1``, ... and bound to the values
    in ``parameters`` by position. Every ``?`` character is rewritten,
    including any inside string literals, so literal question marks must
    be passed as parameters instead.

    If execution fails, the session is rolled back before the error is
    raised so that later queries on the same session still work.

    Parameters
    ----------
    sql : str
        SQL query using ``?`` placeholders.
    parameters : tuple, optional
        Values for the placeholders, in order.

    Returns
    -------
    list of dict
        One dictionary per result row, keyed by column name.

    Raises
    ------
    DatabaseError
        If not connected or if query execution fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     results = db.query("SELECT * FROM papers WHERE session = ?", ("Poster",))
    >>> for row in results:
    ...     print(row['title'])
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        # Splitting on "?" and re-joining numbers the placeholders in a
        # single pass: the i-th "?" becomes ":param{i}".
        head, *rest = sql.split("?")
        converted_sql = head + "".join(
            f":param{index}{chunk}" for index, chunk in enumerate(rest)
        )
        param_dict = {
            f"param{index}": value for index, value in enumerate(parameters)
        }

        result = self._session.execute(text(converted_sql), param_dict)
        return [dict(row._mapping) for row in result]
    except Exception as e:
        # A failed statement can leave the transaction unusable (e.g.
        # PostgreSQL rejects further statements until rollback).
        self._session.rollback()
        raise DatabaseError(f"Query failed: {str(e)}") from e
[docs]
def get_paper_count(self) -> int:
    """
    Count the rows in the papers table.

    Returns
    -------
    int
        Number of papers; 0 when the count query yields no value.

    Raises
    ------
    DatabaseError
        If not connected or if the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        count_stmt = select(func.count()).select_from(Paper)
        total = self._session.execute(count_stmt).scalar()
        return total if total else 0
    except Exception as e:
        raise DatabaseError(f"Failed to count papers: {str(e)}") from e
[docs]
def search_papers(
    self,
    keyword: Optional[str] = None,
    session: Optional[str] = None,
    sessions: Optional[List[str]] = None,
    year: Optional[int] = None,
    years: Optional[List[int]] = None,
    conference: Optional[str] = None,
    conferences: Optional[List[str]] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """
    Search for papers by keyword, session, year and conference.

    All supplied criteria must match (they are combined with AND). For
    each of session/year/conference the list argument wins when it is
    non-empty; otherwise the single-value argument is used if it is
    truthy.

    Parameters
    ----------
    keyword : str, optional
        Case-insensitive substring looked up in title, abstract and keywords.
    session : str, optional
        Single session to filter by (deprecated, use sessions instead).
    sessions : list[str], optional
        Sessions to filter by (matches ANY).
    year : int, optional
        Single year to filter by (deprecated, use years instead).
    years : list[int], optional
        Years to filter by (matches ANY).
    conference : str, optional
        Single conference to filter by (deprecated, use conferences instead).
    conferences : list[str], optional
        Conferences to filter by (matches ANY).
    limit : int, default=100
        Maximum number of results; a falsy value disables the limit.

    Returns
    -------
    list of dict
        Matching papers as dictionaries.

    Raises
    ------
    DatabaseError
        If not connected or if the search fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     papers = db.search_papers(keyword="neural network", limit=10)
    ...     recent = db.search_papers(years=[2024, 2025])
    ...     some = db.search_papers(sessions=["Session 1", "Session 2"])
    >>> for paper in papers:
    ...     print(paper['title'])
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        filters = []

        if keyword:
            pattern = f"%{keyword}%"
            searched_columns = (Paper.title, Paper.abstract, Paper.keywords)
            filters.append(or_(*(column.ilike(pattern) for column in searched_columns)))

        # (column, list argument, deprecated single-value argument)
        membership_filters = (
            (Paper.session, sessions, session),
            (Paper.year, years, year),
            (Paper.conference, conferences, conference),
        )
        for column, many, single in membership_filters:
            if many:
                wanted = many
            elif single:
                wanted = [single]
            else:
                wanted = []
            if wanted:
                filters.append(column.in_(wanted))

        stmt = select(Paper)
        if filters:
            stmt = stmt.where(and_(*filters))
        if limit:
            stmt = stmt.limit(limit)

        matches = self._session.execute(stmt).scalars().all()
        return [self._paper_to_dict(match) for match in matches]
    except Exception as e:
        raise DatabaseError(f"Search failed: {str(e)}") from e
[docs]
def search_papers_keyword(
    self,
    query: str,
    limit: int = 10,
    sessions: Optional[List[str]] = None,
    years: Optional[List[int]] = None,
    conferences: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
    """
    Keyword search whose results carry authors as a list of names.

    Delegates to ``search_papers`` and then replaces each paper's
    semicolon-separated ``authors`` string with a list of stripped names
    (an empty list when the field is missing or empty).

    Parameters
    ----------
    query : str
        Keyword to search in title, abstract, or keywords fields.
    limit : int, optional
        Maximum number of results, by default 10.
    sessions : list of str, optional
        Filter by paper sessions.
    years : list of int, optional
        Filter by publication years.
    conferences : list of str, optional
        Filter by conference names.

    Returns
    -------
    list of dict
        Paper dictionaries with ``authors`` parsed into a list.

    Examples
    --------
    >>> papers = db.search_papers_keyword(
    ...     "neural networks",
    ...     limit=5,
    ...     years=[2024, 2025]
    ... )
    """
    found = self.search_papers(
        keyword=query,
        sessions=sessions,
        years=years,
        conferences=conferences,
        limit=limit,
    )

    formatted = []
    for entry in found:
        # Work on a copy so the dictionaries returned by search_papers
        # are not modified.
        record = dict(entry)
        joined_authors = record.get("authors")
        if joined_authors:
            record["authors"] = [part.strip() for part in joined_authors.split(";")]
        else:
            record["authors"] = []
        formatted.append(record)
    return formatted
[docs]
def get_stats(
    self,
    year: Optional[int] = None,
    conference: Optional[str] = None
) -> Dict[str, Any]:
    """
    Report how many papers match the optional year/conference filters.

    Without filters the total comes from ``get_paper_count``; with at
    least one filter a ``COUNT(*)`` query is issued through ``query``.

    Parameters
    ----------
    year : int, optional
        Filter by specific year.
    conference : str, optional
        Filter by specific conference.

    Returns
    -------
    dict
        Statistics dictionary with:

        - total_papers: int - Number of papers matching filters
        - year: int or None - Filter year if provided
        - conference: str or None - Filter conference if provided

    Examples
    --------
    >>> stats = db.get_stats()
    >>> print(f"Total papers: {stats['total_papers']}")
    >>> stats_2024 = db.get_stats(year=2024)
    >>> print(f"Papers in 2024: {stats_2024['total_papers']}")
    """
    candidate_filters = (("year = ?", year), ("conference = ?", conference))
    active = [(clause, value) for clause, value in candidate_filters if value is not None]

    if not active:
        total_papers = self.get_paper_count()
    else:
        where_clause = " AND ".join(clause for clause, _ in active)
        values = tuple(value for _, value in active)
        rows = self.query(f"SELECT COUNT(*) as count FROM papers WHERE {where_clause}", values)
        total_papers = rows[0]["count"] if rows else 0

    return {
        "total_papers": total_papers,
        "year": year,
        "conference": conference,
    }
def _paper_to_dict(self, paper: Paper) -> Dict[str, Any]:
"""Convert Paper ORM object to dictionary."""
return {
"uid": paper.uid,
"original_id": paper.original_id,
"title": paper.title,
"authors": paper.authors,
"abstract": paper.abstract,
"session": paper.session,
"poster_position": paper.poster_position,
"paper_pdf_url": paper.paper_pdf_url,
"poster_image_url": paper.poster_image_url,
"url": paper.url,
"room_name": paper.room_name,
"keywords": paper.keywords,
"starttime": paper.starttime,
"endtime": paper.endtime,
"award": paper.award,
"year": paper.year,
"conference": paper.conference,
"created_at": paper.created_at,
}
[docs]
def search_authors_in_papers(
    self,
    name: Optional[str] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """
    Find author names that contain ``name`` (case-insensitive).

    Papers whose semicolon-separated ``authors`` field matches are
    fetched (at most ``limit * 10`` distinct author strings), the
    individual names are split out, and collection stops as soon as
    ``limit`` distinct matching names have been gathered.

    Parameters
    ----------
    name : str, optional
        Name to search for (partial match).
    limit : int, default=100
        Maximum number of results to return.

    Returns
    -------
    list of dict
        Unique matching authors as ``{"name": ...}`` dictionaries, sorted
        by name. Empty when ``name`` is empty or there is no connection.

    Raises
    ------
    DatabaseError
        If search fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     authors = db.search_authors_in_papers(name="Huang")
    >>> for author in authors:
    ...     print(author['name'])
    """
    if not name or not self._session:
        return []

    try:
        pattern = f"%{name}%"
        stmt = (
            select(Paper.authors)
            .where(Paper.authors.ilike(pattern))
            .distinct()
            .limit(limit * 10)  # over-fetch: one paper may add few new names
        )
        author_fields = self._session.execute(stmt).scalars().all()

        needle = name.lower()
        matches = set()
        for joined in author_fields:
            if not joined:
                continue
            for candidate in (part.strip() for part in joined.split(";")):
                if needle in candidate.lower():
                    matches.add(candidate)
                    if len(matches) >= limit:
                        break
            if len(matches) >= limit:
                break

        return [{"name": author} for author in sorted(matches)[:limit]]
    except Exception as e:
        raise DatabaseError(f"Author search failed: {str(e)}") from e
[docs]
def get_author_count(self) -> int:
    """
    Count the distinct author names found across all papers.

    Every non-empty ``authors`` field is split on ``;`` and the stripped
    pieces are collected in a set, so the result is an approximation:
    spelling variants of one person count separately.

    Returns
    -------
    int
        Approximate number of unique authors.

    Raises
    ------
    DatabaseError
        If not connected or if the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        has_authors = and_(Paper.authors.isnot(None), Paper.authors != "")
        author_fields = self._session.execute(
            select(Paper.authors).where(has_authors)
        ).scalars().all()

        distinct_names = {
            piece.strip()
            for joined in author_fields
            if joined
            for piece in joined.split(";")
        }
        return len(distinct_names)
    except Exception as e:
        raise DatabaseError(f"Failed to count authors: {str(e)}") from e
[docs]
def get_filter_options(self, year: Optional[int] = None, conference: Optional[str] = None) -> dict:
    """
    Get the distinct values available for the session/year/conference filters.

    Only the session list is narrowed by the ``year`` and ``conference``
    arguments; the year and conference lists always cover the whole table.

    Parameters
    ----------
    year : int, optional
        Restrict the session list to papers of this year.
    conference : str, optional
        Restrict the session list to papers of this conference.

    Returns
    -------
    dict
        Dictionary with keys:

        - 'sessions': distinct non-empty sessions, ascending
        - 'years': distinct non-null years, newest first
        - 'conferences': distinct non-empty conferences, ascending

    Raises
    ------
    DatabaseError
        If not connected or if a query fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     filters = db.get_filter_options()
    ...     filters_2025 = db.get_filter_options(year=2025)
    >>> print(filters['years'])
    [2025, 2024, 2023]
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        scope = [
            column == value
            for column, value in ((Paper.year, year), (Paper.conference, conference))
            if value is not None
        ]

        sessions_stmt = select(Paper.session).distinct()
        if scope:
            sessions_stmt = sessions_stmt.where(and_(*scope))
        sessions_stmt = sessions_stmt.where(
            and_(Paper.session.isnot(None), Paper.session != "")
        ).order_by(Paper.session)

        # Deliberately unscoped: every year present in the table.
        years_stmt = (
            select(Paper.year)
            .distinct()
            .where(Paper.year.isnot(None))
            .order_by(Paper.year.desc())
        )

        # Deliberately unscoped: every conference present in the table.
        conferences_stmt = (
            select(Paper.conference)
            .distinct()
            .where(and_(Paper.conference.isnot(None), Paper.conference != ""))
            .order_by(Paper.conference)
        )

        session_values = list(self._session.execute(sessions_stmt).scalars().all())
        year_values = list(self._session.execute(years_stmt).scalars().all())
        conference_values = list(self._session.execute(conferences_stmt).scalars().all())

        return {
            "sessions": session_values,
            "years": year_values,
            "conferences": conference_values,
        }
    except Exception as e:
        raise DatabaseError(f"Failed to get filter options: {str(e)}") from e
[docs]
def get_embedding_model(self) -> Optional[str]:
    """
    Return the embedding model name of the most recently updated record.

    Returns
    -------
    str or None
        Name of the embedding model, or None if no record exists.

    Raises
    ------
    DatabaseError
        If not connected or if the query fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     model = db.get_embedding_model()
    >>> print(model)
    'text-embedding-qwen3-embedding-4b'
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        newest_first = select(EmbeddingsMetadata.embedding_model).order_by(
            EmbeddingsMetadata.updated_at.desc()
        )
        return self._session.execute(newest_first.limit(1)).scalar_one_or_none()
    except Exception as e:
        raise DatabaseError(f"Failed to get embedding model: {str(e)}") from e
[docs]
def set_embedding_model(self, model_name: str) -> None:
    """
    Record which embedding model the stored embeddings were made with.

    When at least one metadata record exists, the most recently updated
    one is changed in place (model name and ``updated_at``); otherwise a
    new record is inserted.

    Parameters
    ----------
    model_name : str
        Name of the embedding model.

    Raises
    ------
    DatabaseError
        If not connected or if the update fails.

    Examples
    --------
    >>> db = DatabaseManager()
    >>> with db:
    ...     db.set_embedding_model("text-embedding-qwen3-embedding-4b")
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        existing = self._session.execute(
            select(func.count()).select_from(EmbeddingsMetadata)
        ).scalar()

        if not (existing and existing > 0):
            self._session.add(EmbeddingsMetadata(embedding_model=model_name))
        else:
            newest = self._session.execute(
                select(EmbeddingsMetadata)
                .order_by(EmbeddingsMetadata.updated_at.desc())
                .limit(1)
            ).scalar_one()
            newest.embedding_model = model_name
            newest.updated_at = datetime.now(timezone.utc)

        self._session.commit()
        logger.info(f"Set embedding model to: {model_name}")
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to set embedding model: {str(e)}") from e
[docs]
def get_clustering_cache(
    self,
    embedding_model: str,
    reduction_method: str,
    n_components: int,
    clustering_method: str,
    n_clusters: Optional[int] = None,
    clustering_params: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """
    Get cached clustering results matching the parameters.

    Candidates are selected in SQL by model, reduction method, component
    count, clustering method and (when given) ``n_clusters``, newest
    first. The newest candidate whose stored parameters equal
    ``clustering_params`` is returned.

    Parameters are compared after normalisation: key order is ignored,
    and "no parameters" (``None`` or an empty dict on the request side;
    NULL, empty text or an empty JSON object on the stored side) is a
    single value. A lookup without parameters therefore only matches
    entries saved without parameters, and finds such an entry even when
    newer entries with parameters exist.

    Parameters
    ----------
    embedding_model : str
        Name of the embedding model.
    reduction_method : str
        Dimensionality reduction method.
    n_components : int
        Number of components after reduction.
    clustering_method : str
        Clustering algorithm used.
    n_clusters : int, optional
        Number of clusters (for kmeans/agglomerative). Not filtered on
        when None.
    clustering_params : dict, optional
        Additional clustering parameters (e.g., distance_threshold, eps).

    Returns
    -------
    dict or None
        Cached clustering results as dictionary, or None if not found.

    Raises
    ------
    DatabaseError
        If not connected or if the query fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    import json

    def _normalise(params):
        # Canonical text form of a parameter dict; None when there are no
        # parameters, matching how empty parameters are stored as NULL.
        return json.dumps(params, sort_keys=True) if params else None

    try:
        stmt = select(ClusteringCache).where(
            and_(
                ClusteringCache.embedding_model == embedding_model,
                ClusteringCache.reduction_method == reduction_method,
                ClusteringCache.n_components == n_components,
                ClusteringCache.clustering_method == clustering_method,
            )
        )
        if n_clusters is not None:
            stmt = stmt.where(ClusteringCache.n_clusters == n_clusters)
        stmt = stmt.order_by(ClusteringCache.created_at.desc())

        wanted = _normalise(clustering_params)
        # Parameters are stored as JSON text, so they are compared in
        # Python rather than in SQL.
        for entry in self._session.execute(stmt).scalars().all():
            raw_params = entry.clustering_params
            stored = _normalise(json.loads(raw_params)) if raw_params else None
            if stored == wanted:
                return json.loads(entry.results_json)
        return None
    except Exception as e:
        raise DatabaseError(f"Failed to get clustering cache: {str(e)}") from e
[docs]
def save_clustering_cache(
    self,
    embedding_model: str,
    reduction_method: str,
    n_components: int,
    clustering_method: str,
    results: Dict[str, Any],
    n_clusters: Optional[int] = None,
    clustering_params: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Store clustering results as a new cache entry.

    ``results`` is serialised to JSON. ``clustering_params`` is serialised
    to JSON as well, except that a missing or empty value is stored as
    NULL.

    Parameters
    ----------
    embedding_model : str
        Name of the embedding model.
    reduction_method : str
        Dimensionality reduction method.
    n_components : int
        Number of components after reduction.
    clustering_method : str
        Clustering algorithm used.
    results : dict
        Clustering results to cache.
    n_clusters : int, optional
        Number of clusters (for kmeans/agglomerative).
    clustering_params : dict, optional
        Additional clustering parameters.

    Raises
    ------
    DatabaseError
        If not connected or if the save fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        import json

        serialized_params = None
        if clustering_params:
            serialized_params = json.dumps(clustering_params)

        entry = ClusteringCache(
            embedding_model=embedding_model,
            reduction_method=reduction_method,
            n_components=n_components,
            clustering_method=clustering_method,
            n_clusters=n_clusters,
            clustering_params=serialized_params,
            results_json=json.dumps(results),
        )
        self._session.add(entry)
        self._session.commit()

        logger.info(
            f"Saved clustering cache: {clustering_method} with {n_clusters} clusters, "
            f"model={embedding_model}, reduction={reduction_method}"
        )
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to save clustering cache: {str(e)}") from e
[docs]
def clear_clustering_cache(self, embedding_model: Optional[str] = None) -> int:
    """
    Delete clustering cache entries, optionally for one embedding model only.

    The matching entries are loaded and deleted one by one through the
    session, then the deletion is committed.

    Parameters
    ----------
    embedding_model : str, optional
        If given (and non-empty), only entries for this embedding model
        are deleted; otherwise every cache entry is deleted.

    Returns
    -------
    int
        Number of cache entries deleted.

    Raises
    ------
    DatabaseError
        If not connected or if the deletion fails.
    """
    if not self._session:
        raise DatabaseError("Not connected to database")

    try:
        stmt = select(ClusteringCache)
        if embedding_model:
            stmt = stmt.where(ClusteringCache.embedding_model == embedding_model)

        doomed = self._session.execute(stmt).scalars().all()
        count = len(doomed)
        for entry in doomed:
            self._session.delete(entry)
        self._session.commit()

        message = (
            f"Cleared {count} clustering cache entries for model: {embedding_model}"
            if embedding_model
            else f"Cleared all {count} clustering cache entries"
        )
        logger.info(message)
        return count
    except Exception as e:
        self._session.rollback()
        raise DatabaseError(f"Failed to clear clustering cache: {str(e)}") from e