"""
Embeddings Module
=================
This module provides functionality to generate text embeddings for paper abstracts
and store them in a vector database with paper metadata.
The module uses an OpenAI-compatible API (such as LM Studio or blablador) to generate
embeddings and stores them in ChromaDB for efficient similarity search.
"""
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import urlparse
from openai import OpenAI
import chromadb
from chromadb.config import Settings
from .config import get_config
from .database import DatabaseManager
logger = logging.getLogger(__name__)
[docs]
class EmbeddingsError(Exception):
"""Exception raised for embedding operations."""
pass
[docs]
class EmbeddingsManager:
"""
Manager for generating and storing text embeddings.
This class handles:
- Connecting to OpenAI-compatible API for embedding generation
- Creating and managing a ChromaDB collection
- Embedding paper abstracts with metadata
- Similarity search operations
Parameters
----------
lm_studio_url : str, optional
URL of the OpenAI-compatible API endpoint, by default "http://localhost:1234"
model_name : str, optional
Name of the embedding model, by default "text-embedding-qwen3-embedding-4b"
collection_name : str, optional
Name of the ChromaDB collection, by default "papers"
Attributes
----------
lm_studio_url : str
OpenAI-compatible API endpoint URL.
model_name : str
Embedding model name.
embedding_db : str
ChromaDB configuration - URL for HTTP service or path for local storage.
collection_name : str
ChromaDB collection name.
client : chromadb.Client or None
ChromaDB client instance.
collection : chromadb.Collection or None
Active ChromaDB collection.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> em.add_paper(paper_dict)
>>> results = em.search_similar("machine learning", n_results=5)
>>> em.close()
"""
[docs]
def __init__(
self,
lm_studio_url: Optional[str] = None,
auth_token: Optional[str] = None,
model_name: Optional[str] = None,
collection_name: Optional[str] = None,
):
"""
Initialize the EmbeddingsManager.
Parameters are optional and will use values from environment/config if not provided.
Parameters
----------
lm_studio_url : str, optional
URL of the OpenAI-compatible API endpoint. If None, uses config value.
model_name : str, optional
Name of the embedding model. If None, uses config value.
collection_name : str, optional
Name of the ChromaDB collection. If None, uses config value.
"""
config = get_config()
self.lm_studio_url = (lm_studio_url or config.llm_backend_url).rstrip("/")
self.llm_backend_auth_token = auth_token or config.llm_backend_auth_token
self.model_name = model_name or config.embedding_model
# Get ChromaDB configuration from config
self.embedding_db = config.embedding_db
self.collection_name = collection_name or config.collection_name
self.client: Optional[Any] = None # chromadb.Client
self.collection: Optional[Any] = None # chromadb.Collection
# OpenAI client - lazy loaded on first use to avoid API calls during test collection
self._openai_client: Optional[OpenAI] = None
@property
def openai_client(self) -> OpenAI:
"""
Get the OpenAI client, creating it lazily on first access.
This lazy loading prevents API calls during test collection.
Returns
-------
OpenAI
Initialized OpenAI client instance.
"""
if self._openai_client is None:
self._openai_client = OpenAI(
base_url=f"{self.lm_studio_url}/v1",
api_key=self.llm_backend_auth_token or "lm-studio-local"
)
return self._openai_client
[docs]
def connect(self) -> None:
"""
Connect to ChromaDB.
Uses HTTP client if embedding_db is a URL, otherwise uses persistent client
with local storage directory.
Raises
------
EmbeddingsError
If connection fails.
"""
try:
if self.embedding_db.startswith("http://") or self.embedding_db.startswith("https://"):
# Use HTTP client for remote ChromaDB service
# Parse URL properly using urllib
parsed = urlparse(self.embedding_db)
host = parsed.hostname or "localhost"
port = parsed.port or 8000
self.client = chromadb.HttpClient(
host=host,
port=port,
settings=Settings(anonymized_telemetry=False),
)
logger.debug(f"Connected to ChromaDB HTTP service at: {self.embedding_db}")
else:
# Use persistent client for local storage
chroma_path = Path(self.embedding_db)
chroma_path.mkdir(parents=True, exist_ok=True)
self.client = chromadb.PersistentClient(
path=str(chroma_path),
settings=Settings(anonymized_telemetry=False),
)
logger.debug(f"Connected to ChromaDB at: {chroma_path}")
except Exception as e:
raise EmbeddingsError(f"Failed to connect to ChromaDB: {str(e)}") from e
[docs]
def close(self) -> None:
"""
Close the ChromaDB connection.
Does nothing if not connected.
"""
if self.client:
self.client = None
self.collection = None
logger.debug("ChromaDB connection closed")
[docs]
def __enter__(self):
"""Context manager entry."""
self.connect()
return self
[docs]
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit."""
self.close()
[docs]
def test_lm_studio_connection(self) -> bool:
"""
Test connection to OpenAI-compatible API endpoint.
Returns
-------
bool
True if connection is successful, False otherwise.
Examples
--------
>>> em = EmbeddingsManager()
>>> if em.test_lm_studio_connection():
... print("API is accessible")
"""
try:
# Try to get models list
_ = self.openai_client.models.list()
logger.debug(f"Successfully connected to OpenAI API at {self.lm_studio_url}")
return True
except Exception as e:
logger.warning(f"Failed to connect to OpenAI API: {str(e)}")
return False
[docs]
def generate_embedding(self, text: str) -> List[float]:
"""
Generate embedding for a given text using OpenAI-compatible API.
Parameters
----------
text : str
Text to generate embedding for.
Returns
-------
List[float]
Embedding vector.
Raises
------
EmbeddingsError
If embedding generation fails.
Examples
--------
>>> em = EmbeddingsManager()
>>> embedding = em.generate_embedding("Sample text")
>>> len(embedding)
4096
"""
if not text or not text.strip():
raise EmbeddingsError("Cannot generate embedding for empty text")
try:
response = self.openai_client.embeddings.create(
model=self.model_name,
input=text
)
if not response.data or len(response.data) == 0:
raise EmbeddingsError("No embedding data in API response")
embedding = response.data[0].embedding
logger.debug(f"Generated embedding with dimension: {len(embedding)}")
return embedding
except Exception as e:
raise EmbeddingsError(f"Failed to generate embedding via OpenAI API: {str(e)}") from e
[docs]
def create_collection(self, reset: bool = False) -> None:
"""
Create or get ChromaDB collection.
Parameters
----------
reset : bool, optional
If True, delete existing collection and create new one, by default False
Raises
------
EmbeddingsError
If collection creation fails or not connected.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> em.create_collection(reset=True) # Reset existing collection
"""
if not self.client:
raise EmbeddingsError("Not connected to ChromaDB")
try:
if reset:
try:
self.client.delete_collection(name=self.collection_name)
logger.info(f"Deleted existing collection: {self.collection_name}")
except Exception:
pass # Collection might not exist
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "NeurIPS paper abstracts and metadata"},
)
logger.debug(f"Created/retrieved collection: {self.collection_name}")
except Exception as e:
raise EmbeddingsError(f"Failed to create collection: {str(e)}") from e
[docs]
def paper_exists(self, paper_id: str) -> bool:
"""
Check if a paper already exists in the collection.
Parameters
----------
paper_id : int or str
Unique identifier for the paper.
Returns
-------
bool
True if paper exists in collection, False otherwise.
Raises
------
EmbeddingsError
If collection not initialized.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> em.paper_exists("uid1")
False
>>> em.add_paper(paper_dict)
>>> em.paper_exists("uid1")
True
"""
if not self.collection:
raise EmbeddingsError("Collection not initialized. Call create_collection() first.")
try:
# Try to get the paper by ID
result = self.collection.get(ids=[paper_id])
# If the result has any IDs, the paper exists
return len(result["ids"]) > 0
except Exception as e:
logger.warning(f"Error checking if paper {paper_id} exists: {str(e)}")
return False
[docs]
def paper_needs_update(self, paper: dict) -> bool:
"""
Check if a paper needs to be updated in the collection.
Parameters
----------
paper : dict
Dictionary containing paper information.
Returns
-------
bool
True if the paper needs to be updated, False otherwise.
Raises
------
EmbeddingsError
If collection not initialized.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> em.paper_needs_update({"id": 1, "abstract": "Updated abstract"})
True
>>> em.paper_needs_update({"id": 1, "abstract": "This paper presents..."})
False
"""
if not self.collection:
raise EmbeddingsError("Collection not initialized. Call create_collection() first.")
try:
existing_paper = self.collection.get(ids=[paper["uid"]])
if not existing_paper or len(existing_paper["ids"]) == 0:
return True # Paper does not exist, needs to be added
# Compare existing embedding text with new paper data
existing_documents = existing_paper.get("documents", [])
if not existing_documents:
return True # No document stored, needs update
existing_embedding_text = existing_documents[0]
new_embedding_text = self.embedding_text_from_paper(paper)
return existing_embedding_text != new_embedding_text
except Exception as e:
logger.warning(f"Error checking if paper {paper['uid']} needs update: {str(e)}")
return False
[docs]
@staticmethod
def embedding_text_from_paper(paper: dict) -> str:
"""
Extract text for embedding from a paper dictionary.
Parameters
----------
paper : dict
Dictionary containing paper information.
Returns
-------
str
Text to be used for embedding.
"""
title = paper.get("title", "") or ""
abstract = paper.get("abstract", "") or ""
embedding_text = f"{title}\n\n{abstract}".strip()
if not embedding_text:
raise ValueError(f"Cannot create embedding text for paper {paper['uid']}: no abstract and no title")
return embedding_text
[docs]
def add_paper(self, paper: dict) -> None:
"""
Add a paper to the vector database.
Parameters
----------
paper : dict
Dictionary containing paper information. Must follow the paper database schema.
Raises
------
EmbeddingsError
If adding paper fails or collection not initialized.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> em.add_paper(paper_dict)
"""
if not self.collection:
raise EmbeddingsError("Collection not initialized. Call create_collection() first.")
try:
embedding_text = self.embedding_text_from_paper(paper)
# Generate embedding if not provided
embedding = self.generate_embedding(embedding_text)
# Prepare metadata - convert all values to strings for ChromaDB compatibility
meta = paper.copy()
meta = {k: str(v) if v is not None else "" for k, v in meta.items()}
# Add to collection
self.collection.add(
embeddings=[embedding],
documents=[embedding_text],
metadatas=[meta],
ids=[paper["uid"]],
)
logger.debug(f"Added paper {paper['uid']} to collection")
except Exception as e:
raise EmbeddingsError(f"Failed to add paper {paper['uid']}: {str(e)}") from e
[docs]
def search_similar(
self,
query: str,
n_results: int = 10,
where: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Search for similar papers using semantic similarity.
Parameters
----------
query : str
Query text to search for.
n_results : int, optional
Number of results to return, by default 10
where : dict, optional
Metadata filter conditions.
Returns
-------
dict
Search results containing ids, distances, documents, and metadatas.
Raises
------
EmbeddingsError
If search fails or collection not initialized.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> results = em.search_similar("deep learning transformers", n_results=5, where={"year": 2025})
>>> for i, paper_id in enumerate(results['ids'][0]):
... print(f"{i+1}. Paper {paper_id}: {results['metadatas'][0][i]}")
"""
if not self.collection:
raise EmbeddingsError("Collection not initialized. Call create_collection() first.")
if not query or not query.strip():
raise EmbeddingsError("Query cannot be empty")
try:
# Generate embedding for query
query_embedding = self.generate_embedding(query)
# Search in collection
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
where=where,
)
logger.info(f"Found {len(results['ids'][0])} similar papers")
return dict(results) # type: ignore[arg-type]
except Exception as e:
raise EmbeddingsError(f"Failed to search: {str(e)}") from e
[docs]
def get_collection_stats(self) -> Dict[str, Any]:
"""
Get statistics about the collection.
Returns
-------
dict
Statistics including count, name, and metadata.
Raises
------
EmbeddingsError
If collection not initialized.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> stats = em.get_collection_stats()
>>> print(f"Collection has {stats['count']} papers")
"""
if not self.collection:
raise EmbeddingsError("Collection not initialized. Call create_collection() first.")
try:
return {
"name": self.collection.name,
"count": self.collection.count(),
"metadata": self.collection.metadata,
}
except Exception as e:
raise EmbeddingsError(f"Failed to get collection stats: {str(e)}") from e
[docs]
def check_model_compatibility(self) -> Tuple[bool, Optional[str], Optional[str]]:
"""
Check if the current embedding model matches the one stored in the database.
Returns
-------
tuple of (bool, str or None, str or None)
- compatible: True if models match or no model is stored, False if they differ
- stored_model: Name of the model stored in the database, or None if not set
- current_model: Name of the current model
Raises
------
EmbeddingsError
If database operations fail.
Examples
--------
>>> em = EmbeddingsManager()
>>> compatible, stored, current = em.check_model_compatibility()
>>> if not compatible:
... print(f"Model mismatch: stored={stored}, current={current}")
"""
try:
# Use DatabaseManager to check the stored model
db_manager = DatabaseManager()
db_manager.connect()
stored_model = db_manager.get_embedding_model()
db_manager.close()
# If no model is stored, consider it compatible (first time embedding)
if stored_model is None:
return True, None, self.model_name
# Check if models match
compatible = stored_model == self.model_name
return compatible, stored_model, self.model_name
except Exception as e:
raise EmbeddingsError(f"Failed to check model compatibility: {str(e)}") from e
[docs]
def embed_from_database(
self,
where_clause: Optional[str] = None,
progress_callback: Optional[Callable[[int, int], None]] = None,
force_recreate: bool = False,
) -> int:
"""
Embed papers from the database.
Reads papers from the database and generates embeddings for their abstracts.
Parameters
----------
where_clause : str, optional
SQL WHERE clause to filter papers (e.g., "decision = 'Accept'")
progress_callback : callable, optional
Callback function to report progress. Called with (current, total) number of papers processed.
force_recreate : bool, optional
If True, skip checking for existing embeddings and recreate all, by default False
Returns
-------
int
Number of papers successfully embedded.
Raises
------
EmbeddingsError
If database reading or embedding fails.
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> count = em.embed_from_database()
>>> print(f"Embedded {count} papers")
>>> # Only embed accepted papers
>>> count = em.embed_from_database(where_clause="decision = 'Accept'")
"""
if not self.collection:
raise EmbeddingsError("Collection not initialized. Call create_collection() first.")
try:
# Use DatabaseManager for database operations
db_manager = DatabaseManager()
db_manager.connect()
# Store the embedding model in the database
db_manager.set_embedding_model(self.model_name)
query = "SELECT * FROM papers"
if where_clause:
query += f" WHERE {where_clause}"
rows = db_manager.query(query)
total = len(rows)
logger.debug(f"Found {total} papers to embed")
if total == 0:
db_manager.close()
return 0
# Process papers one by one
embedded_count = 0
skipped_count = 0
for i, row in enumerate(rows):
# Convert sqlite3.Row to dict
paper = dict(row)
# Check if paper already exists in the collection and if it needs to be updated
# Skip this check if force_recreate is True
if not force_recreate and not self.paper_needs_update(paper):
logger.debug(f"Skipping paper {paper['uid']}: already exists in collection")
skipped_count += 1
# Still call progress callback to update the progress bar
if progress_callback:
progress_callback(i + 1, total)
continue
else:
try:
self.add_paper(paper)
embedded_count += 1
# Call progress callback if provided
if progress_callback:
progress_callback(i + 1, total)
except Exception as e:
logger.error(f"Failed to embed paper {paper['uid']}: {str(e)}")
continue
db_manager.close()
logger.info(f"Successfully embedded {embedded_count} papers, skipped {skipped_count} existing papers")
return embedded_count
except Exception as e:
raise EmbeddingsError(f"Failed to embed from database: {str(e)}") from e
[docs]
def search_papers_semantic(
self,
query: str,
database,
limit: int = 10,
sessions: Optional[List[str]] = None,
years: Optional[List[int]] = None,
conferences: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""
Perform semantic search for papers using embeddings.
This function combines embedding-based similarity search with metadata filtering
and retrieves complete paper information from the database.
Parameters
----------
query : str
Search query text
database : DatabaseManager
Database manager for retrieving full paper details
limit : int, optional
Maximum number of results to return, by default 10
sessions : list of str, optional
Filter by paper sessions
years : list of int, optional
Filter by publication years
conferences : list of str, optional
Filter by conference names
Returns
-------
list of dict
List of paper dictionaries with complete information
Raises
------
EmbeddingsError
If search fails
Examples
--------
>>> papers = em.search_papers_semantic(
... "transformers in vision",
... database=db,
... limit=5,
... years=[2024, 2025]
... )
"""
from .paper_utils import format_search_results, PaperFormattingError
# Build metadata filter for embeddings search
# NOTE: All metadata is stored as strings in ChromaDB (see add_paper method, line 445)
# so we must convert filter values to strings for matching
filter_conditions: List[Dict[str, Any]] = []
if sessions:
filter_conditions.append({"session": {"$in": sessions}})
if years:
# Convert years to strings to match ChromaDB metadata storage format
year_strs: List[str] = [str(y) for y in years]
filter_conditions.append({"year": {"$in": year_strs}})
if conferences:
filter_conditions.append({"conference": {"$in": conferences}})
# Use $and operator if multiple conditions, otherwise use single condition
where_filter: Optional[Dict[str, Any]] = None
if len(filter_conditions) > 1:
where_filter = {"$and": filter_conditions}
elif len(filter_conditions) == 1:
where_filter = filter_conditions[0]
logger.info(f"Semantic search - query: {query}, filter: sessions={sessions}, years={years}, conferences={conferences}")
logger.info(f"Where filter: {where_filter}")
# Get more results initially to account for filtering
results = self.search_similar(query, n_results=limit * 2, where=where_filter)
logger.info(f"Search results count: {len(results.get('ids', [[]])[0]) if results else 0}")
# Transform ChromaDB results to paper format using shared utility
try:
papers = format_search_results(results, database, include_documents=False)
except PaperFormattingError:
# No valid papers found
return []
# Limit results (filtering already done at database level)
return papers[:limit]
[docs]
def find_papers_within_distance(
self,
database,
query: str,
distance_threshold: float = 1.1,
conferences: Optional[List[str]] = None,
years: Optional[List[int]] = None,
) -> Dict[str, Any]:
"""
Find papers within a specified distance from a custom search query.
This method treats the search query as a clustering center and returns
papers within the specified Euclidean distance radius in embedding space.
Parameters
----------
database : DatabaseManager
Database manager instance for retrieving paper details
query : str
The search query text
distance_threshold : float, optional
Euclidean distance radius, by default 1.1
conferences : list[str], optional
Filter results to only include papers from these conferences
years : list[int], optional
Filter results to only include papers from these years
Returns
-------
dict
Dictionary containing:
- query: str - The search query
- query_embedding: list[float] - The generated embedding for the query
- distance: float - The distance threshold used
- papers: list[dict] - Papers within the distance radius with their distances
- count: int - Number of papers found
Raises
------
EmbeddingsError
If embeddings collection is empty or operation fails
Examples
--------
>>> em = EmbeddingsManager()
>>> em.connect()
>>> em.create_collection()
>>> db = DatabaseManager()
>>> db.connect()
>>> results = em.find_papers_within_distance(db, "machine learning", 1.1)
>>> print(f"Found {results['count']} papers")
>>>
>>> # With filters
>>> results = em.find_papers_within_distance(
... db, "deep learning", 1.1,
... conferences=["NeurIPS"],
... years=[2023, 2024]
... )
"""
from abstracts_explorer.paper_utils import get_paper_with_authors, PaperFormattingError
if not self.collection:
raise EmbeddingsError("Collection not initialized. Call create_collection() first.")
if not query or not query.strip():
raise EmbeddingsError("Query cannot be empty")
try:
# Generate embedding for the query
query_embedding = self.generate_embedding(query)
# Get total count of papers in collection
total_count = self.collection.count()
if total_count == 0:
raise EmbeddingsError("No papers in collection")
# Build where clause for filtering
where_clause: Optional[Dict[str, Any]] = None
if conferences or years:
filters: list[Dict[str, Any]] = []
if conferences:
if len(conferences) == 1:
filters.append({"conference": conferences[0]})
else:
filters.append({"conference": {"$in": conferences}})
if years:
if len(years) == 1:
filters.append({"year": years[0]})
else:
filters.append({"year": {"$in": years}})
# Combine filters with $and if multiple
if len(filters) == 1:
where_clause = filters[0]
else:
where_clause = {"$and": filters}
# Query all papers and get distances
# Using collection.query() which returns papers sorted by distance
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=total_count, # Get all papers
include=["distances", "metadatas"],
where=where_clause
)
# Extract results (query returns nested lists)
paper_ids = results['ids'][0] if results.get('ids') else []
distances = results['distances'][0] if results.get('distances') else []
if not paper_ids:
raise EmbeddingsError("No results from collection query")
# Filter papers within distance threshold
matching_papers = []
for idx, (paper_id, distance) in enumerate(zip(paper_ids, distances)):
if distance <= distance_threshold:
# Get full paper details from database using uid
try:
paper_dict = get_paper_with_authors(database, paper_id)
paper_dict["distance"] = float(distance)
matching_papers.append(paper_dict)
except PaperFormattingError:
# Paper not found in database, skip it
logger.warning(f"Paper {paper_id} not found in database, skipping")
continue
else:
# Since results are sorted by distance, we can break early
break
return {
"query": query,
"query_embedding": query_embedding,
"distance": distance_threshold,
"papers": matching_papers,
"count": len(matching_papers),
}
except EmbeddingsError:
# Re-raise EmbeddingsError as-is
raise
except Exception as e:
logger.error(f"Error finding papers within distance: {e}", exc_info=True)
raise EmbeddingsError(f"Failed to find papers within distance: {str(e)}") from e