"""
Registry module for uploading and downloading data to/from OCI-compatible container registries.

This module provides functionality to push and pull abstracts-explorer data artifacts
(paper databases, embedding databases, clustering caches) to OCI-compatible registries
such as GitHub Container Registry (ghcr.io).

Artifacts are pushed and pulled using the `oras <https://oras-project.github.io/oras-py/>`_
Python SDK. Each artifact is tagged by conference (e.g. ``neurips``) or by conference
and year (e.g. ``neurips-2024``), followed by the embedding model and the
abstracts-explorer version (e.g. ``neurips-2024_text-embedding-ada-002_0.4.1``).
A conference-only tag contains all available years, with each year stored as its
own set of OCI layers (paper DB + embeddings + clustering cache).

Examples
--------
Upload data for a specific year::

    from abstracts_explorer.registry import RegistryClient

    client = RegistryClient(
        repository="ghcr.io/thawn/abstracts-data",
        token="ghp_xxxx",
    )
    client.upload(conference="neurips", year=2024)

Upload all available years for a conference::

    client.upload(conference="neurips")

Download data from the registry::

    client.download(conference="neurips", year=2024)

List available tags::

    tags = client.list_tags()
"""
import json
import logging
import os
import re
import shutil
import sqlite3
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
import oras.client
import oras.defaults
import oras.oci
import oras.provider
import requests
from packaging.version import InvalidVersion, Version
from abstracts_explorer._version import __version__
from abstracts_explorer.config import get_config
# Module-level logger; upload() also mirrors its progress messages here at INFO level.
logger = logging.getLogger(__name__)
# Custom media types for abstracts-explorer artifacts
# NOTE(review): none of these constants is referenced in the upload path visible
# in this module (upload()/_push_tag() pass plain file paths to oras without an
# explicit media type); confirm where they are consumed before renaming them.
PAPER_DB_MEDIA_TYPE = "application/vnd.abstracts-explorer.paper-db.v1.tar+gzip"
EMBEDDING_DB_MEDIA_TYPE = "application/vnd.abstracts-explorer.embedding-db.v1.tar+gzip"
CLUSTERING_CACHE_MEDIA_TYPE = "application/vnd.abstracts-explorer.clustering-cache.v1.tar+gzip"
CONFIG_MEDIA_TYPE = "application/vnd.abstracts-explorer.config.v1+json"
[docs]
class RegistryError(Exception):
    """Base exception for failures of registry operations in this module (upload, download, validation)."""
[docs]
class EmbeddingModelMismatchError(RegistryError):
    """
    Raised when a downloaded artifact was built with a different embedding model.

    Importing embeddings produced by another model would mix incompatible
    vectors, so imports are refused unless the caller explicitly opts out.

    Attributes
    ----------
    local_model : str
        Embedding model expected on the local side (the model of the local
        database or the configured model, depending on the caller).
    remote_model : str
        Embedding model recorded in the downloaded artifact.
    """

    def __init__(self, local_model: str, remote_model: str) -> None:
        self.local_model = local_model
        self.remote_model = remote_model
        message = (
            f"Embedding model mismatch: local database uses '{local_model}' "
            f"but downloaded artifact uses '{remote_model}'. "
            "Cannot import data created with a different embedding model."
        )
        super().__init__(message)
def _sanitize_str_for_oci_tag(value: str) -> str:
"""
Sanitize a string for use as an OCI tag component.
The value is lowercased and characters not in ``[a-z0-9._-]`` are
replaced with ``-``. OCI tags allow ``[a-zA-Z0-9_.-]``. In
particular the ``+`` local-version separator used by PEP 440
(e.g. ``1.2.3+g1a2b3c4``) is replaced with ``-``. Consecutive
hyphens are collapsed and leading/trailing hyphens are stripped.
Parameters
----------
value : str
String to sanitize (e.g. a model name or a PEP 440 version).
Returns
-------
str
Tag-safe string (e.g. ``text-embedding-ada-002`` or
``0.1.dev2-g2abcfb2a2``).
"""
safe = value.lower()
safe = re.sub(r"[^a-z0-9._-]", "-", safe)
# Collapse consecutive hyphens
safe = re.sub(r"-{2,}", "-", safe)
return safe.strip("-")
def _build_tag(
    conference: str,
    year: Optional[int] = None,
    *,
    embedding_model: str,
    version: Optional[str] = None,
) -> str:
    """
    Compose the OCI tag ``{conference}[-{year}]_{model}_{version}``.

    Parameters
    ----------
    conference : str
        Conference name. It is lowercased, and spaces, ``/`` and ``@`` are
        turned into ``-``.
    year : int, optional
        Conference year. When ``None``, the base part is only the
        conference name (e.g. ``neurips``).
    embedding_model : str
        Embedding model name. It is sanitized for OCI use and appended after a
        ``_`` separator (e.g. ``neurips-2024_text-embedding-ada-002_1.0.0``).
    version : str, optional
        abstracts-explorer version string. Defaults to the installed package
        version (``__version__``). It is sanitized for OCI use and appended
        after a second ``_`` separator.

    Returns
    -------
    str
        Tag string (e.g. ``neurips-2024_text-embedding-ada-002_0.1.0``).
    """
    resolved_version = __version__ if version is None else version

    base = conference.lower()
    for forbidden in (" ", "/", "@"):
        base = base.replace(forbidden, "-")
    if year is not None:
        base = f"{base}-{year}"

    model_part = _sanitize_str_for_oci_tag(embedding_model)
    version_part = _sanitize_str_for_oci_tag(resolved_version)
    return f"{base}_{model_part}_{version_part}"
def _parse_version_from_tag(tag: str) -> Optional[Version]:
    """
    Recover the abstracts-explorer version embedded in an OCI tag.

    Tags have the form ``{conference}[-{year}]_{model}_{version}``. The
    version is always the text after the last ``_``. Sanitization keeps dots,
    so plain releases such as ``0.4.1`` parse directly.

    ``_sanitize_str_for_oci_tag`` rewrites the PEP 440 local-version ``+``
    as ``-``. For example, ``0.4.6.dev16+g7005b7837`` appears in a tag as
    ``0.4.6.dev16-g7005b7837``. If the raw text does not parse, each ``-`` is
    tried in turn (left to right) as the original ``+``. The first candidate
    that parses wins.

    Parameters
    ----------
    tag : str
        OCI tag without the repository prefix, e.g.
        ``neurips-2024_text-embedding-ada-002_0.4.1`` or
        ``ml4ps-neurips-2022_model_0.4.6.dev16-g7005b7837``.

    Returns
    -------
    packaging.version.Version or None
        The parsed version. Returns ``None`` when the tag has no ``_``
        or no candidate is a valid PEP 440 version.
    """
    _prefix, separator, raw = tag.rpartition("_")
    if not separator:
        return None

    pieces = raw.split("-")
    attempts = [raw]
    attempts.extend(
        "-".join(pieces[:cut]) + "+" + "-".join(pieces[cut:]) for cut in range(1, len(pieces))
    )
    for attempt in attempts:
        try:
            return Version(attempt)
        except InvalidVersion:
            pass
    return None
[docs]
class RegistryClient:
"""
Client for pushing and pulling data artifacts to/from OCI-compatible registries.
Uses the `oras <https://oras-project.github.io/oras-py/>`_ Python SDK to
interact with OCI registries.
The smallest unit of upload/download is a **conference + year** combination.
Each artifact always contains the paper database, embeddings, and clustering
cache together to prevent inconsistent data.
When ``year`` is omitted, all available years for the conference are uploaded
or downloaded, with each year stored as its own set of three OCI layers
(paper DB, embeddings, clustering cache).
Parameters
----------
repository : str
Full OCI repository path (e.g., ``ghcr.io/thawn/abstracts-data``).
token : str, optional
Authentication token (e.g., GitHub Personal Access Token).
If not provided, will try the ``GITHUB_TOKEN`` environment variable.
Raises
------
RegistryError
If the repository format is invalid.
Examples
--------
>>> client = RegistryClient("ghcr.io/thawn/abstracts-data", token="ghp_xxxx")
>>> client.list_tags()
['neurips-2024', 'iclr-2025']
"""
[docs]
def __init__(self, repository: str, token: Optional[str] = None):
    """
    Parse *repository*, resolve the token and create an authenticated oras client.

    The text before the first ``/`` is treated as the registry host; the
    remainder is the repository name. Login is only attempted when a token
    is available (argument or ``GITHUB_TOKEN``).
    """
    registry_host, _slash, repo_name = repository.partition("/")
    # No '/' leaves repo_name empty, so one check covers every malformed case.
    if not registry_host or not repo_name:
        raise RegistryError(
            f"Invalid repository format: '{repository}'. "
            "Expected format: 'registry/owner/name' (e.g., 'ghcr.io/thawn/abstracts-data')"
        )

    self.registry = registry_host
    self.name = repo_name
    self.repository = repository
    self.token = token or os.environ.get("GITHUB_TOKEN", "")

    self._client = oras.client.OrasClient(hostname=self.registry, insecure=False)
    if not self.token:
        return
    self._client.login(
        username="_token",
        password=self.token,
        hostname=self.registry,
    )
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _get_years_for_conference(conference: str) -> List[int]:
    """
    List the years stored in the local database for *conference*.

    Parameters
    ----------
    conference : str
        Conference name.

    Returns
    -------
    list of int
        Years as returned by ``DatabaseManager.get_years_for_conference``
        (documented as sorted — confirm in ``DatabaseManager``).
    """
    from abstracts_explorer.database import DatabaseManager

    with DatabaseManager() as manager:
        manager.create_tables()
        found_years = manager.get_years_for_conference(conference)
    return found_years
@staticmethod
def _get_embedding_model_database() -> Optional[str]:
    """
    Look up the embedding model recorded in the local database.

    Returns
    -------
    str or None
        Embedding model name, or ``None`` when none has been recorded yet.
    """
    from abstracts_explorer.database import DatabaseManager

    with DatabaseManager() as manager:
        manager.create_tables()
        stored_model = manager.get_embedding_model()
    return stored_model
@staticmethod
def _is_conference_level_tag(tag: str) -> bool:
"""
Return ``True`` when *tag* is a conference-level tag (no year suffix).
OCI tags in this project have the format
``{conference}[-{year}]_{model}_{version}``. Year-specific tags
contain a ``-YYYY`` suffix in the base component (before the first
``_``), e.g. ``chi-2026_model_0.4.1``. Conference-level tags
omit the year, e.g. ``chi_model_0.4.1``.
Download operations should only use conference-level tags so that
each year is not imported twice (once from its dedicated year tag
and again from the combined conference tag).
Parameters
----------
tag : str
OCI tag string (without repository prefix).
Returns
-------
bool
``True`` if the tag does not carry a year suffix.
"""
base = tag.split("_", 1)[0]
parts = base.rsplit("-", 1)
if len(parts) != 2:
return True
suffix = parts[1]
# A year suffix is exactly 4 digits.
return not (suffix.isdigit() and len(suffix) == 4)
def _find_best_matching_tag(self, tag: str) -> str:
    """
    Resolve *tag* to the best matching tag available in the registry.

    OCI tags in this project have the format
    ``{conference}[-year]_{model}_{version}``. The exact tag may be missing
    from the registry, e.g. because the local package version differs from
    the version used when the artifact was pushed. In that case this method
    strips ``_``-separated components from the end of *tag* one at a time. It
    then looks for registry tags that share the remaining prefix:

    * first the same ``{conference}[-year]_{model}`` with any version;
    * if that matches nothing, progressively shorter prefixes, down to the
      bare ``{conference}[-year]`` with any model and version.

    Among the matching candidates, the one with the highest PEP 440 version
    is returned (parsed with ``_parse_version_from_tag``). So ``0.4.10``
    beats ``0.4.9``, and a release beats its own ``.devN`` builds. Candidates
    whose version cannot be parsed rank below parseable ones. Ties are broken
    by comparing the full tag strings.

    If the exact tag exists, it is returned unchanged. If listing tags fails
    or no candidate matches, *tag* is also returned unchanged. The caller can
    then still attempt the operation and produce an informative error.

    Parameters
    ----------
    tag : str
        OCI tag to resolve (without repository prefix, e.g.
        ``neurips-2024_my-model_0.4.2``).

    Returns
    -------
    str
        The resolved tag.
    """
    try:
        available_tags = self._client.get_tags(self.repository)
    except Exception as exc:
        logger.debug("Could not list tags for tag resolution: %s", exc)
        return tag

    if tag in available_tags or "_" not in tag:
        return tag

    # Strip one trailing '_' component per iteration. ``maxsplit`` grows in
    # step, so the same number of trailing components is removed from every
    # registry tag before it is compared with ``prefix``.
    candidates: List[str] = []
    prefix = tag
    maxsplit = 1
    while not candidates and "_" in prefix:
        prefix = prefix.rsplit("_", 1)[0]
        candidates = [t for t in available_tags if t.rsplit("_", maxsplit=maxsplit)[0] == prefix]
        maxsplit += 1
    if not candidates:
        return tag

    def _version_rank(candidate: str) -> Tuple[bool, Version, str]:
        # Plain string comparison would rank "0.4.9" above "0.4.10" and a
        # ".devN-g<hash>" build above its release. Compare parsed versions.
        parsed = _parse_version_from_tag(candidate)
        return (parsed is not None, parsed if parsed is not None else Version("0"), candidate)

    resolved = max(candidates, key=_version_rank)
    logger.debug("Tag '%s' not found; resolved to closest match '%s'", tag, resolved)
    return resolved
def _get_manifest_embedding_model(self, target: str) -> Optional[str]:
"""
Retrieve the embedding model name from the OCI manifest for *target*.
Fetches the manifest from the registry and reads the
``com.abstracts-explorer.embedding-model`` label without downloading any
artifact data. Checks both the ``labels`` field (used by
abstracts-explorer when pushing) and the ``annotations`` field
(used by some alternative OCI implementations).
Parameters
----------
target : str
Full OCI reference including tag (e.g.
``ghcr.io/thawn/abstracts-data:neurips-2024_model_1.0.0``).
Returns
-------
str or None
The embedding model stored in the manifest, or ``None`` if the
manifest has no such label (e.g. legacy artifacts) or if the
manifest could not be fetched.
"""
try:
manifest = self._client.get_manifest(target)
labels = manifest.get("labels") or manifest.get("annotations") or {}
result = labels.get("com.abstracts-explorer.embedding-model")
return result if isinstance(result, str) else None
except Exception as exc:
logger.debug("Could not fetch manifest for %s: %s", target, exc)
return None
[docs]
@staticmethod
def clear_local_embedding_data() -> None:
    """
    Wipe all local embedding data so data from a different model can be imported.

    In the SQLite database, all ``EmbeddingsMetadata`` rows are deleted, the
    clustering cache is cleared, and all ``HierarchicalLabelCache`` rows are
    deleted, followed by a commit. Afterwards the ChromaDB collection is
    recreated with ``reset=True``.

    This is destructive; use with care. After it runs, the next download
    imports fresh data and establishes a new embedding-model association in
    the local database.
    """
    from abstracts_explorer.database import DatabaseManager
    from abstracts_explorer.db_models import EmbeddingsMetadata, HierarchicalLabelCache
    from abstracts_explorer.embeddings import EmbeddingsManager
    from sqlalchemy import delete as sa_delete

    # SQLite side: model metadata plus both clustering caches, one commit.
    with DatabaseManager() as db:
        db.create_tables()
        db._session.execute(sa_delete(EmbeddingsMetadata))  # type: ignore[union-attr]
        db.clear_clustering_cache()
        db._session.execute(sa_delete(HierarchicalLabelCache))  # type: ignore[union-attr]
        db._session.commit()  # type: ignore[union-attr]

    # Vector-store side: drop and recreate the collection.
    EmbeddingsManager().create_collection(reset=True)
def _export_year(
    self,
    conference: str,
    year: int,
    temp_dir: Path,
    progress: Callable[[str], None],
) -> Dict[str, Any]:
    """
    Write the paper DB, embeddings and clustering cache of one conference/year to *temp_dir*.

    The three parts are exported in that order, as ``papers-{year}.db``,
    ``embeddings-{year}.json`` and ``clustering-{year}.json``. The clustering
    cache is kept out of the paper DB because it is pushed as its own layer.

    Returns
    -------
    dict
        Keys ``paper_db_path``, ``embeddings_path``, ``clustering_cache_path``,
        ``paper_count``, ``embedding_count`` and ``clustering_cache_count``.

    Raises
    ------
    RegistryError
        If any of the three parts is empty for this conference/year.
    """
    from abstracts_explorer.database import DatabaseManager
    from abstracts_explorer.embeddings import EmbeddingsManager

    # 1) Papers.
    progress(f"Exporting paper database for {conference}/{year}...")
    papers_file = temp_dir / f"papers-{year}.db"
    with DatabaseManager() as db:
        db.create_tables()
        n_papers = db.export_papers_to_sqlite(papers_file, conference, year)
    if n_papers == 0:
        raise RegistryError(f"No papers found for {conference}/{year}. Download the conference data first.")
    progress(f" Exported {n_papers} papers")

    # 2) Embeddings.
    progress(f"Exporting embeddings for {conference}/{year}...")
    exported_vectors = EmbeddingsManager().export_embeddings(conference, year)
    n_vectors = len(exported_vectors.get("ids", []))
    if n_vectors == 0:
        raise RegistryError(
            f"No embeddings found for {conference}/{year}."
            " Create embeddings first with 'abstracts-explorer create-embeddings'."
        )
    vectors_file = temp_dir / f"embeddings-{year}.json"
    vectors_file.write_text(json.dumps(exported_vectors))
    progress(f" Exported {n_vectors} embeddings")

    # 3) Clustering cache (separate layer).
    progress(f"Exporting clustering cache for {conference}/{year}...")
    with DatabaseManager() as db:
        db.create_tables()
        cache_payload = db.export_clustering_cache_to_json(conference, year)
    n_cache_entries = len(cache_payload.get("entries", []))
    if n_cache_entries == 0:
        raise RegistryError(
            f"No clustering cache found for {conference}/{year}."
            " Generate the clustering cache first with 'abstracts-explorer clustering pre-generate'."
        )
    cache_file = temp_dir / f"clustering-{year}.json"
    cache_file.write_text(json.dumps(cache_payload))
    progress(f" Exported {n_cache_entries} clustering cache entries")

    return {
        "paper_db_path": papers_file,
        "embeddings_path": vectors_file,
        "clustering_cache_path": cache_file,
        "paper_count": n_papers,
        "embedding_count": n_vectors,
        "clustering_cache_count": n_cache_entries,
    }
@staticmethod
def _read_artifact_embedding_model(paper_db_file: Path) -> Optional[str]:
"""
Read the embedding model name from the ``embeddings_metadata`` table
in a downloaded artifact's paper DB.
Parameters
----------
paper_db_file : Path
Path to the artifact's SQLite paper database.
Returns
-------
str or None
The embedding model stored in the artifact, or ``None`` if
the table does not exist or is empty (legacy artifacts).
"""
try:
with sqlite3.connect(str(paper_db_file)) as conn:
row = conn.execute(
"SELECT embedding_model FROM embeddings_metadata ORDER BY updated_at DESC LIMIT 1"
).fetchone()
return row[0] if row else None
except (sqlite3.OperationalError, sqlite3.DatabaseError) as exc:
logger.debug(
"Could not read embedding model from artifact DB %s: %s: %s",
paper_db_file.name,
type(exc).__name__,
exc,
)
return None
@staticmethod
def _replace_artifact_embedding_model(paper_db_file: Path, new_model: str) -> None:
"""
Overwrite the embedding model in the artifact's paper DB so that
subsequent imports do not trigger model-consistency checks.
Parameters
----------
paper_db_file : Path
Path to the artifact's SQLite paper database.
new_model : str
The model name to write.
"""
try:
with sqlite3.connect(str(paper_db_file)) as conn:
conn.execute(
"UPDATE embeddings_metadata SET embedding_model = ? "
"WHERE rowid = (SELECT rowid FROM embeddings_metadata ORDER BY updated_at DESC LIMIT 1)",
(new_model,),
)
except (sqlite3.OperationalError, sqlite3.DatabaseError) as exc:
logger.debug(
"Could not update embedding model in artifact DB %s: %s: %s",
paper_db_file.name,
type(exc).__name__,
exc,
)
@staticmethod
def _check_embedding_model(
    paper_db_file: Path,
    embedding_model: str,
    ignore_embedding_model_mismatch: bool,
    progress: Callable[[str], None],
) -> None:
    """
    Single authoritative embedding-model check for one conference/year import.

    Reads the model stored in the artifact paper DB and compares it with
    *embedding_model*, after OCI-tag sanitization of both names.

    * No model in the artifact (legacy data), or matching names: nothing happens.
    * Different names and *ignore_embedding_model_mismatch* is ``False``:
      ``EmbeddingModelMismatchError`` is raised.
    * Different names and *ignore_embedding_model_mismatch* is ``True``:
      a warning is reported through *progress*, and the artifact DB is patched
      in place. The downstream ``import_papers_from_sqlite()`` consistency
      check therefore does not fire again.

    Parameters
    ----------
    paper_db_file : Path
        Path to the artifact's SQLite paper database.
    embedding_model : str
        The configured/expected embedding model name.
    ignore_embedding_model_mismatch : bool
        When ``True``, overwrite the artifact model and continue.
    progress : callable
        Status-message callback.

    Raises
    ------
    EmbeddingModelMismatchError
        When models differ and *ignore_embedding_model_mismatch* is ``False``.
    """
    stored_model = RegistryClient._read_artifact_embedding_model(paper_db_file)
    if not stored_model or _sanitize_str_for_oci_tag(stored_model) == _sanitize_str_for_oci_tag(embedding_model):
        return

    if ignore_embedding_model_mismatch:
        progress(
            f"⚠️ Embedding model mismatch ignored for {paper_db_file.name}:\n"
            f" Configured model: '{embedding_model}'\n"
            f" Artifact model: '{stored_model}'\n"
            f" Replacing artifact model with configured model."
        )
        RegistryClient._replace_artifact_embedding_model(paper_db_file, embedding_model)
        return

    raise EmbeddingModelMismatchError(local_model=embedding_model, remote_model=stored_model)
def _import_year(
    self,
    conference: str,
    year: int,
    paper_db_file: Path,
    embeddings_file: Path,
    progress: Callable[[str], None],
    embedding_model: Optional[str] = None,
    ignore_embedding_model_mismatch: bool = False,
    clustering_cache_file: Optional[Path] = None,
) -> Dict[str, Any]:
    """
    Import paper DB, embeddings, and clustering cache for a single conference+year.

    All three files (*paper_db_file*, *embeddings_file*, and
    *clustering_cache_file*) must exist. This is checked before anything is
    imported. The imports then run in the order paper DB → embeddings →
    clustering cache, with the following failure behaviour:

    * Paper DB import fails: nothing has been imported by this call.
    * Embedding import fails: all ``Paper`` rows for this conference+year
      are deleted again (best effort), so the paper DB and the embedding DB
      do not disagree. Rows that were stored locally before this import are
      not restored.
    * Clustering-cache import fails: **no** rollback is performed. The paper
      DB and the embeddings stay imported (and consistent with each other);
      only the clustering cache is missing.

    The embedding-model consistency check happens **here**; this is the
    single authoritative location. When *ignore_embedding_model_mismatch* is
    ``True`` and the models differ, the artifact DB is patched before the
    import. The downstream ``import_papers_from_sqlite()`` check then does not
    trigger again.

    Returns a dict with ``paper_count``, ``embedding_count``, and
    ``clustering_cache_count``.

    Raises
    ------
    EmbeddingModelMismatchError
        If the artifact's embedding model differs from *embedding_model*
        and *ignore_embedding_model_mismatch* is ``False``, or if the
        database layer reports an embedding-model conflict.
    RegistryError
        If any file is missing or an import step fails.
    """
    from abstracts_explorer.database import DatabaseManager, EmbeddingModelConflictError
    from abstracts_explorer.embeddings import EmbeddingsManager

    # --- pre-flight: all three files must exist ---
    missing: List[str] = []
    if not paper_db_file.exists():
        missing.append(f"paper DB ({paper_db_file.name})")
    if not embeddings_file.exists():
        missing.append(f"embeddings ({embeddings_file.name})")
    if clustering_cache_file is None or not clustering_cache_file.exists():
        missing.append(
            f"clustering cache ({clustering_cache_file.name if clustering_cache_file else 'not provided'})"
        )
    if missing:
        raise RegistryError(
            f"Incomplete data for {conference}/{year}: missing {', '.join(missing)}. "
            "Cannot import — paper DB, embeddings, and clustering cache must all be present."
        )

    # --- single authoritative embedding-model check ---
    if embedding_model:
        self._check_embedding_model(paper_db_file, embedding_model, ignore_embedding_model_mismatch, progress)

    # --- import paper DB first ---
    progress(f"Importing paper database for {conference}/{year}...")
    try:
        with DatabaseManager() as db:
            db.create_tables()
            paper_count = db.import_papers_from_sqlite(paper_db_file, conference, year)
    except EmbeddingModelConflictError as conflict_err:
        # The DB layer's own model check fired; surface it as the registry-level error type.
        raise EmbeddingModelMismatchError(conflict_err.local_model, conflict_err.remote_model) from conflict_err
    except Exception as db_err:
        raise RegistryError(f"Paper DB import failed for {conference}/{year}: {db_err}") from db_err
    progress(f" Imported {paper_count} papers")

    # --- import embeddings; roll back paper rows on failure ---
    try:
        progress(f"Importing embeddings for {conference}/{year}...")
        embeddings_data = json.loads(embeddings_file.read_text())
        em = EmbeddingsManager()
        embedding_count = em.import_embeddings(embeddings_data, conference, year)
        progress(f" Imported {embedding_count} embeddings")
    except Exception as embed_err:
        progress(f" Embedding import failed — rolling back paper DB for {conference}/{year}...")
        try:
            from sqlalchemy import and_ as sa_and
            from sqlalchemy import delete as sa_delete

            from abstracts_explorer.db_models import Paper

            with DatabaseManager() as db:
                db.create_tables()
                db._session.execute(  # type: ignore[union-attr]
                    sa_delete(Paper).where(sa_and(Paper.conference == conference, Paper.year == year))
                )
                db._session.commit()  # type: ignore[union-attr]
        except Exception:
            # Best effort: keep the original embedding error as the one raised.
            logger.warning("Failed to roll back paper DB import after embedding failure", exc_info=True)
        raise RegistryError(
            f"Embedding import failed for {conference}/{year}: {embed_err}. "
            "Paper DB changes have been rolled back."
        ) from embed_err

    # --- import clustering cache from its separate layer (no rollback on failure) ---
    clustering_cache_count = 0
    progress(f"Importing clustering cache for {conference}/{year}...")
    try:
        # Pre-flight guarantees clustering_cache_file is a Path here.
        cache_data = json.loads(clustering_cache_file.read_text())  # type: ignore[union-attr]
        with DatabaseManager() as db:
            db.create_tables()
            clustering_cache_count = db.import_clustering_cache_from_json(
                cache_data,
                conference,
                year,
                overwrite_embedding_model=embedding_model or None,
            )
        progress(f" Imported {clustering_cache_count} clustering cache entries")
    except Exception as cache_err:
        raise RegistryError(f"Clustering cache import failed for {conference}/{year}: {cache_err}") from cache_err

    return {
        "paper_count": paper_count,
        "embedding_count": embedding_count,
        "clustering_cache_count": clustering_cache_count,
    }
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def _push_tag(
    self,
    conference: str,
    years_in_tag: List[int],
    files: List[str],
    config_path: Path,
    tag: str,
    embedding_model: str,
    paper_count: int,
    embedding_count: int,
    progress: Callable[[str], None],
) -> None:
    """Push *files* under ``{repository}:{tag}`` with descriptive manifest annotations.

    Parameters
    ----------
    conference : str
        Conference name stored in the manifest annotations.
    years_in_tag : list of int
        Years whose data files are included in this tag; stored comma-separated.
    files : list of str
        Absolute paths to the blob files to push.
    config_path : Path
        Path to the JSON config blob.
    tag : str
        OCI tag string (without repository prefix).
    embedding_model : str
        Embedding model name stored in the manifest annotations.
    paper_count : int
        Total paper count stored in the manifest annotations.
    embedding_count : int
        Total embedding count stored in the manifest annotations.
    progress : callable
        Progress message callback; receives a success message after the push.
    """
    reference = f"{self.repository}:{tag}"
    years_csv = ",".join(map(str, years_in_tag))
    annotations = {
        "com.abstracts-explorer.version": __version__,
        "com.abstracts-explorer.conference": conference,
        "com.abstracts-explorer.years": years_csv,
        "com.abstracts-explorer.paper-count": str(paper_count),
        "com.abstracts-explorer.embedding-count": str(embedding_count),
        "com.abstracts-explorer.embedding-model": embedding_model,
    }
    self._client.push(
        target=reference,
        files=files,
        manifest_config=str(config_path),
        manifest_annotations=annotations,
        disable_path_validation=True,
    )
    progress(f"Successfully pushed {reference}")
[docs]
def upload(
    self,
    conference: str,
    year: Optional[int] = None,
    tag: Optional[str] = None,
    progress_callback: Optional[Callable[[str], None]] = None,
) -> Dict[str, Any]:
    """
    Upload data for a conference (and optionally a specific year) to the registry.

    Packages the paper database, embeddings, and clustering cache as OCI
    layers and pushes them together. All three must be present for every
    year; an error is raised if any data is missing.

    Every requested year is exported to a temporary directory **before**
    anything is pushed. Missing data for any year therefore aborts the upload
    without leaving a partial set of tags in the registry. Errors raised by
    the registry while pushing can still leave earlier tags in place.

    When *year* is not ``None``, a single per-year tag is pushed
    (e.g. ``neurips-2024_<model>_<version>``).

    When *year* is ``None``, every year available locally is first pushed as
    its own individual tag (e.g. ``neurips-2024_<model>_<version>``,
    ``neurips-2025_<model>_<version>``). Then an all-years summary tag
    (e.g. ``neurips_<model>_<version>``) is pushed, containing all years'
    files as layers. Because OCI blobs are content-addressed, the registry
    deduplicates the files, so no data is actually stored twice.

    Parameters
    ----------
    conference : str
        Conference name (e.g. ``neurips``).
    year : int, optional
        Conference year (e.g. ``2024``). When ``None``, all available
        years are uploaded.
    tag : str, optional
        Custom tag for the single-year / summary push. If ``None``, derived
        from conference, year, embedding model and package version.
    progress_callback : callable, optional
        Function called with status messages during upload.

    Returns
    -------
    dict
        Upload summary with tag, conference, years, paper count, embedding
        count and clustering cache count. In all-years mode it also contains
        ``year_tags``, listing the per-year tags pushed.

    Raises
    ------
    RegistryError
        If upload fails or required data is missing.
    """

    def _progress(msg: str) -> None:
        if progress_callback:
            progress_callback(msg)
        logger.info(msg)

    def _write_config(
        path: Path,
        years_in_tag: List[int],
        model: str,
        paper_count: int,
        embedding_count: int,
        clustering_cache_count: int,
    ) -> Path:
        # Same keys and order for per-year and summary config blobs.
        config_data = {
            "version": __version__,
            "conference": conference,
            "years": years_in_tag,
            "paper_count": paper_count,
            "embedding_count": embedding_count,
            "clustering_cache_count": clustering_cache_count,
            "embedding_model": model,
        }
        path.write_text(json.dumps(config_data, indent=2))
        return path

    # --- Determine embedding model (needed for auto-tag) ---
    embedding_model = self._get_embedding_model_database()
    if not embedding_model:
        raise RegistryError(
            "No embedding model found in local database. "
            "Create embeddings first with 'abstracts-explorer create-embeddings'."
        )

    # --- Determine which years to upload ---
    if year is not None:
        years = [year]
    else:
        years = self._get_years_for_conference(conference)
        if not years:
            raise RegistryError(
                f"No data found for conference '{conference}'. Download the conference data first."
            )
        _progress(f"Found years for {conference}: {years}")

    # Target tag for the single-year upload or the all-years summary.
    if tag is None:
        tag = _build_tag(conference, year, embedding_model=embedding_model)

    temp_dir = Path(tempfile.mkdtemp())
    try:
        # Export (and thereby validate) every year up front: _export_year
        # raises RegistryError for missing data, before any tag is pushed.
        exports = [(yr, self._export_year(conference, yr, temp_dir, _progress)) for yr in years]

        all_files: List[str] = []
        total_papers = 0
        total_embeddings = 0
        total_clustering_cache = 0
        year_tags: List[str] = []

        for yr, yr_data in exports:
            yr_files = [
                str(yr_data["paper_db_path"]),
                str(yr_data["embeddings_path"]),
                str(yr_data["clustering_cache_path"]),
            ]
            all_files.extend(yr_files)
            total_papers += yr_data["paper_count"]
            total_embeddings += yr_data["embedding_count"]
            total_clustering_cache += yr_data["clustering_cache_count"]

            # In all-years mode each year is also pushed under its own tag.
            if year is None:
                yr_tag = _build_tag(conference, yr, embedding_model=embedding_model)
                year_tags.append(yr_tag)
                yr_config_path = _write_config(
                    temp_dir / f"config-{yr}.json",
                    [yr],
                    embedding_model,
                    yr_data["paper_count"],
                    yr_data["embedding_count"],
                    yr_data["clustering_cache_count"],
                )
                _progress(f"Uploading {yr_tag}...")
                self._push_tag(
                    conference=conference,
                    years_in_tag=[yr],
                    files=yr_files,
                    config_path=yr_config_path,
                    tag=yr_tag,
                    embedding_model=embedding_model,
                    paper_count=yr_data["paper_count"],
                    embedding_count=yr_data["embedding_count"],
                    progress=_progress,
                )

        # --- Push the final (single-year or all-years summary) tag ---
        config_path = _write_config(
            temp_dir / "config.json",
            years,
            embedding_model,
            total_papers,
            total_embeddings,
            total_clustering_cache,
        )
        _progress(f"Uploading {tag}...")
        self._push_tag(
            conference=conference,
            years_in_tag=years,
            files=all_files,
            config_path=config_path,
            tag=tag,
            embedding_model=embedding_model,
            paper_count=total_papers,
            embedding_count=total_embeddings,
            progress=_progress,
        )

        summary: Dict[str, Any] = {
            "tag": tag,
            "conference": conference,
            "years": years,
            "paper_count": total_papers,
            "embedding_count": total_embeddings,
            "clustering_cache_count": total_clustering_cache,
        }
        if year_tags:
            summary["year_tags"] = year_tags
        return summary
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
[docs]
def download(
    self,
    conference: str,
    year: Optional[int] = None,
    tag: Optional[str] = None,
    embedding_model: Optional[str] = None,
    progress_callback: Optional[Callable[[str], None]] = None,
    ignore_embedding_model_mismatch: bool = False,
) -> Dict[str, Any]:
    """
    Download data for a conference (and optionally a specific year) from the registry.

    Pulls the paper database, embeddings, and clustering cache and
    replaces existing local data for the specified conference and year(s).
    When *year* is ``None``, all years contained in the artifact are
    downloaded. When *year* is given but the artifact does not contain it,
    nothing is imported and a warning is reported via the progress callback.

    Parameters
    ----------
    conference : str
        Conference name (e.g. ``neurips``).
    year : int, optional
        Conference year (e.g. ``2024``). When ``None``, all years
        in the artifact are imported.
    tag : str, optional
        Custom tag. If ``None``, derived from embedding model, conference and year.
        In both cases the tag is then resolved to the closest matching tag
        that exists in the registry before anything is pulled.
    embedding_model : str, optional
        Embedding model name. When ``None``, it is read from the
        ``EMBEDDING_MODEL`` configuration (regardless of whether *tag* is
        given). It is used to derive the tag when *tag* is ``None``, to
        compare against the artifact's model before pulling, and when
        importing each year.
        A ``RegistryError`` is raised if the model cannot be determined.
    progress_callback : callable, optional
        Function called with status messages during download.
    ignore_embedding_model_mismatch : bool, optional
        When ``True``, proceed with the download even if the artifact's embedding model
        differs from the configured model. After a successful import the local embedding
        model metadata is updated to match *embedding_model*. Only use this option when
        the mismatch is caused by the same model having different names on different
        backends (e.g. LM Studio vs. Ollama). Default is ``False``.

    Returns
    -------
    dict
        Download summary with keys ``tag`` (the resolved tag), ``conference``,
        ``years`` (imported years, ascending), ``paper_count``,
        ``embedding_count``, ``clustering_cache_count`` and ``metadata``
        (parsed ``config.json`` from the artifact, or ``{}`` if absent).

    Raises
    ------
    EmbeddingModelMismatchError
        If the artifact's embedding model differs from *embedding_model* and
        *ignore_embedding_model_mismatch* is ``False``.
    RegistryError
        If the embedding model cannot be determined, a year in the artifact
        is missing its paper DB, embeddings or clustering cache, or the
        pull/import fails for any other reason.
    """
    if embedding_model is None:
        embedding_model = get_config().embedding_model
    if not embedding_model:
        raise RegistryError(
            "No embedding model specified and none found in the configuration. "
            "Use --embedding-model to specify the model name."
        )
    if tag is None:
        tag = _build_tag(conference, year, embedding_model=embedding_model)

    def _progress(msg: str) -> None:
        if progress_callback:
            progress_callback(msg)
        logger.info(msg)

    # --- 0a. Resolve tag: find the best matching tag in the registry ---
    # The locally-built tag includes the current package version, which may
    # differ from the version used when the artifact was pushed. Resolve to
    # the closest matching tag so that both the manifest check and the pull
    # use a tag that actually exists.
    resolved_tag = self._find_best_matching_tag(tag)
    if resolved_tag != tag:
        _progress(f"Tag '{tag}' not found in registry; using closest match '{resolved_tag}'")
        tag = resolved_tag
    target = f"{self.repository}:{tag}"

    # --- 0b. Pre-download: check embedding model from manifest labels ---
    # Fail fast before pulling data when the artifact's model does not
    # match and the user has not opted to ignore the mismatch.
    manifest_embedding_model = self._get_manifest_embedding_model(target)
    if manifest_embedding_model and embedding_model:
        if _sanitize_str_for_oci_tag(manifest_embedding_model) != _sanitize_str_for_oci_tag(embedding_model):
            if not ignore_embedding_model_mismatch:
                raise EmbeddingModelMismatchError(
                    local_model=embedding_model,
                    remote_model=manifest_embedding_model,
                )

    # Per-year artifact files are named "<prefix><YYYY><suffix>"; each
    # pattern maps to the key used in the per-year file dict below.
    year_file_patterns = (
        ("papers-", ".db", "paper_db"),
        ("embeddings-", ".json", "embeddings"),
        ("clustering-", ".json", "clustering_cache"),
    )

    temp_dir = Path(tempfile.mkdtemp())
    try:
        # --- 1. Pull from oras ---
        _progress(f"Pulling {target}...")
        pulled_files = self._client.pull(target=target, outdir=str(temp_dir))
        _progress(f"Downloaded {len(pulled_files)} files")

        # Read config metadata if available
        metadata: Dict[str, Any] = {}
        for fpath in pulled_files:
            p = Path(fpath)
            if p.name == "config.json":
                metadata = json.loads(p.read_text())
                _progress(f"Artifact version: {metadata.get('version', 'unknown')}")
                break

        # --- 2. Group files by year ---
        year_files: Dict[int, Dict[str, Path]] = {}
        for fpath in pulled_files:
            p = Path(fpath)
            name = p.name
            for prefix, suffix, key in year_file_patterns:
                if not (name.startswith(prefix) and name.endswith(suffix)):
                    continue
                try:
                    yr = int(name[len(prefix) : -len(suffix)])
                except ValueError:
                    logger.warning(f"Skipping file with invalid year format: {name}")
                else:
                    year_files.setdefault(yr, {})[key] = p
                # A file matches at most one pattern.
                break

        # If user requested a specific year, filter
        if year is not None:
            year_files = {yr: files for yr, files in year_files.items() if yr == year}

        # --- 3. Validate completeness (before importing anything) ---
        for yr in sorted(year_files):
            files = year_files[yr]
            missing = []
            if not files.get("paper_db"):
                missing.append("paper DB")
            if not files.get("embeddings"):
                missing.append("embeddings")
            if not files.get("clustering_cache"):
                missing.append("clustering cache")
            if missing:
                raise RegistryError(
                    f"Incomplete data for {conference}/{yr}: missing {', '.join(missing)}. "
                    "Cannot import — paper DB, embeddings, and clustering cache must all be present."
                )

        # --- 4. Import each year ---
        total_papers = 0
        total_embeddings = 0
        total_clustering_cache = 0
        imported_years: List[int] = []
        for yr in sorted(year_files):
            files = year_files[yr]
            result = self._import_year(
                conference,
                yr,
                files["paper_db"],
                files["embeddings"],
                _progress,
                embedding_model=embedding_model,
                ignore_embedding_model_mismatch=ignore_embedding_model_mismatch,
                clustering_cache_file=files["clustering_cache"],
            )
            total_papers += result["paper_count"]
            total_embeddings += result["embedding_count"]
            total_clustering_cache += result["clustering_cache_count"]
            imported_years.append(yr)

        if not imported_years:
            _progress("Warning: No data found in artifact to import")
        _progress("Download complete!")
        return {
            "tag": tag,
            "conference": conference,
            "years": imported_years,
            "paper_count": total_papers,
            "embedding_count": total_embeddings,
            "clustering_cache_count": total_clustering_cache,
            "metadata": metadata,
        }
    except (RegistryError, EmbeddingModelMismatchError):
        # Domain errors propagate unchanged. This includes a model mismatch
        # detected while importing, which must not be re-wrapped as a
        # generic "Failed to download" error.
        raise
    except Exception as e:
        raise RegistryError(f"Failed to download: {e}") from e
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
[docs]
def upload_all(
    self,
    progress_callback: Optional[Callable[[str], None]] = None,
) -> List[Dict[str, Any]]:
    """
    Upload data for **all** conferences available locally.

    The conferences are taken from the local database in sorted order.
    Each one is pushed by :meth:`upload` without a year, so it becomes a
    separate OCI artifact with a conference-only tag that contains all of
    that conference's years.

    Parameters
    ----------
    progress_callback : callable, optional
        Function called with status messages during upload. It is also
        forwarded to every per-conference :meth:`upload` call.

    Returns
    -------
    list of dict
        Upload summaries, one per conference, in sorted conference order.

    Raises
    ------
    RegistryError
        If no conferences are found in the local database. Errors raised
        by :meth:`upload` propagate and abort the remaining uploads.
    """
    from abstracts_explorer.database import DatabaseManager

    with DatabaseManager() as db:
        # Make sure the schema exists before querying it.
        db.create_tables()
        conferences = sorted(db.get_conferences())

    if not conferences:
        raise RegistryError("No conference data found in local database.")

    def _report(message: str) -> None:
        if progress_callback:
            progress_callback(message)
        logger.info(message)

    _report(f"Found {len(conferences)} conference(s): {conferences}")

    results: List[Dict[str, Any]] = []
    for conference_name in conferences:
        _report(f"\n--- Uploading {conference_name} ---")
        results.append(self.upload(conference=conference_name, progress_callback=progress_callback))
    return results
[docs]
def download_all(
    self,
    progress_callback: Optional[Callable[[str], None]] = None,
    ignore_embedding_model_mismatch: bool = False,
) -> List[Dict[str, Any]]:
    """
    Download data for **all** conference tags in the registry.

    Lists available tags and downloads every conference-level tag
    (i.e. tags without a year suffix). For each tag, the conference and
    year are read from the manifest annotations
    (``com.abstracts-explorer.conference`` / ``com.abstracts-explorer.years``).
    When an annotation is missing, the value is derived from the tag name.
    If the annotations list exactly one year, only that year is requested.
    Otherwise all years in the artifact are downloaded.

    :meth:`download` is called without an explicit embedding model, so the
    configured model is used for the mismatch check.

    Parameters
    ----------
    progress_callback : callable, optional
        Function called with status messages during download.
    ignore_embedding_model_mismatch : bool, optional
        If True, ignore embedding model mismatches during download.

    Returns
    -------
    list of dict
        Download summaries, one per conference tag, in sorted tag order.

    Raises
    ------
    RegistryError
        If no tags are found or any download fails.
    """
    tags = self.list_tags()
    if not tags:
        raise RegistryError("No tags found in registry.")

    # Only download conference-level tags (e.g. "chi_model_0.4.1").
    # Year-specific tags (e.g. "chi-2026_model_0.4.1") are subsets of the
    # conference tag and would cause each year to be imported twice.
    conference_tags = [t for t in tags if self._is_conference_level_tag(t)]
    if not conference_tags:
        raise RegistryError("No conference-level tags found in registry.")

    def _progress(msg: str) -> None:
        if progress_callback:
            progress_callback(msg)
        logger.info(msg)

    def _parse_tag(tag_name: str) -> Tuple[str, Optional[int]]:
        """Split ``conference[-year]_model...`` (or legacy ``conference[-year]``) into (conference, year)."""
        base = tag_name.split("_", 1)[0]
        head, sep, tail = base.rpartition("-")
        # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
        if sep and tail.isdecimal():
            return head, int(tail)
        return base, None

    _progress(
        f"Found {len(conference_tags)} conference tag(s) in registry (skipping {len(tags) - len(conference_tags)} year-specific tag(s))"
    )

    summaries: List[Dict[str, Any]] = []
    for tag in sorted(conference_tags):
        _progress(f"\n--- Downloading {tag} ---")

        # Read manifest annotations to derive conference/year
        try:
            info = self.get_artifact_info(tag)
            # A manifest may carry an explicit null for annotations.
            annotations = info.get("annotations") or {}
            conf = annotations.get("com.abstracts-explorer.conference", "")
            years_str = annotations.get("com.abstracts-explorer.years", "")
        except RegistryError:
            conf = ""
            years_str = ""

        tag_conf, tag_year = _parse_tag(tag)
        if not conf:
            conf = tag_conf

        yr: Optional[int] = None
        if years_str:
            year_vals = [int(y) for y in years_str.split(",") if y.strip().isdecimal()]
            # If multiple years, leave yr=None to download all
            if len(year_vals) == 1:
                yr = year_vals[0]
        else:
            yr = tag_year

        summary = self.download(
            conference=conf,
            year=yr,
            tag=tag,
            progress_callback=progress_callback,
            ignore_embedding_model_mismatch=ignore_embedding_model_mismatch,
        )
        summaries.append(summary)
    return summaries
[docs]
def get_artifact_info(self, tag: str) -> Dict[str, Any]:
    """
    Get metadata about a specific artifact tag.

    Fetches the OCI manifest for ``{repository}:{tag}`` and returns only the
    manifest's top-level annotations and a simplified view of its layers.

    Parameters
    ----------
    tag : str
        Tag to inspect.

    Returns
    -------
    dict
        ``{"tag", "annotations", "layers"}``. Each layer entry holds
        ``media_type`` (default ``""``), ``size`` (default ``0``) and
        ``annotations`` (default ``{}``).

    Raises
    ------
    RegistryError
        If the manifest cannot be fetched or has an unexpected structure.
        Any underlying exception is chained.
    """
    try:
        manifest = self._client.get_manifest(f"{self.repository}:{tag}")
        top_level_annotations = manifest.get("annotations", {})
        layer_summaries = [
            {
                "media_type": entry.get("mediaType", ""),
                "size": entry.get("size", 0),
                "annotations": entry.get("annotations", {}),
            }
            for entry in manifest.get("layers", [])
        ]
        return {
            "tag": tag,
            "annotations": top_level_annotations,
            "layers": layer_summaries,
        }
    except Exception as e:
        raise RegistryError(f"Failed to get artifact info for '{tag}': {e}") from e
# ------------------------------------------------------------------
# GitHub Packages API helpers (deletion)
# ------------------------------------------------------------------
def _github_api_headers(self) -> Dict[str, str]:
    """
    Build the headers sent with every GitHub REST API request.

    Always includes the GitHub JSON media type and the pinned API version.
    A bearer ``Authorization`` header is added only when ``self.token`` is
    truthy.
    """
    auth_header = {"Authorization": f"Bearer {self.token}"} if self.token else {}
    return {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
        **auth_header,
    }
def _list_github_package_versions(self) -> List[Dict[str, Any]]:
    """
    List all versions of the GHCR package via the GitHub Packages API.

    Pages through the ``versions`` endpoint, 100 entries per page. The first
    page is requested from the ``users`` endpoint. On a 404 the ``orgs``
    endpoint is tried instead. Whichever endpoint answered the first page
    is reused for every later page, so org-owned packages do not trigger a
    failing ``users`` request per page.

    Returns
    -------
    list of dict
        Each entry contains at minimum ``id``, ``name`` (digest), and
        ``metadata.container.tags`` (list of OCI tags for that version).

    Raises
    ------
    RegistryError
        If the registry is not hosted on ``ghcr.io``, the repository path
        cannot be split into owner and package name, a request fails, the
        API returns a non-200 status, or a response body is not a JSON list.
    """
    if self.registry.lower() != "ghcr.io":
        raise RegistryError(
            f"Deletion via the GitHub Packages API is only supported for 'ghcr.io' registries, "
            f"not '{self.registry}'."
        )

    # self.name is "{owner}/{package_name}" (everything after "ghcr.io/")
    name_parts = self.name.split("/", 1)
    if len(name_parts) != 2 or not name_parts[0] or not name_parts[1]:
        raise RegistryError(
            f"Cannot determine owner and package name from repository '{self.repository}'. "
            "Expected format: 'ghcr.io/{{owner}}/{{package-name}}'."
        )
    owner, package_name = name_parts
    # URL-encode the package name (slashes → %2F)
    package_name_encoded = package_name.replace("/", "%2F")
    headers = self._github_api_headers()
    per_page = 100

    def _get_page(endpoint: str, page: int) -> requests.Response:
        """GET one page of versions from the ``users``/``orgs`` endpoint, wrapping transport errors."""
        url = (
            f"https://api.github.com/{endpoint}/{owner}/packages/container"
            f"/{package_name_encoded}/versions?per_page={per_page}&page={page}"
        )
        try:
            return requests.get(url, headers=headers, timeout=30)
        except requests.RequestException as e:
            raise RegistryError(f"GitHub API request failed: {e}") from e

    # "users" or "orgs"; unknown until the first page has been fetched successfully.
    endpoint: Optional[str] = None
    all_versions: List[Dict[str, Any]] = []
    page = 1
    while True:
        response = _get_page(endpoint or "users", page)
        if endpoint is None and response.status_code == 404:
            # Not found as a user package — it may be org-owned.
            response = _get_page("orgs", page)
            if response.status_code != 200:
                raise RegistryError(
                    f"GitHub API returned HTTP {response.status_code} for package '{package_name}' "
                    f"under owner/org '{owner}'. Verify the repository path and token permissions."
                )
            endpoint = "orgs"
        if response.status_code == 401:
            raise RegistryError("GitHub API authentication failed. Check that your token is valid.")
        if response.status_code == 403:
            raise RegistryError("GitHub API access forbidden. Ensure your token has the 'delete:packages' scope.")
        if response.status_code != 200:
            raise RegistryError(f"GitHub API returned HTTP {response.status_code}: {response.text[:200]}")
        if endpoint is None:
            endpoint = "users"

        try:
            page_data = response.json()
        except ValueError as e:
            raise RegistryError(f"GitHub API returned invalid JSON: {e}") from e
        if not page_data:
            break
        if not isinstance(page_data, list):
            raise RegistryError(f"Unexpected GitHub API response (expected a list): {response.text[:200]}")
        all_versions.extend(page_data)
        if len(page_data) < per_page:
            break
        page += 1
    return all_versions
def _delete_github_package_version(self, owner: str, package_name: str, version_id: int) -> None:
    """
    Delete a single package version via the GitHub Packages API.

    The ``users`` endpoint is tried first. Only when it answers 404 is the
    request repeated against the ``orgs`` endpoint, because the package may
    be organisation-owned.

    Parameters
    ----------
    owner : str
        GitHub username or organisation name.
    package_name : str
        Package name (e.g. ``abstracts-data``). Slashes are percent-encoded.
    version_id : int
        Numeric version ID returned by the list-versions endpoint.

    Raises
    ------
    RegistryError
        If a request cannot be sent, or the API answers with anything other
        than 200/204 (apart from the one tolerated 404 on ``users``).
    """
    request_headers = self._github_api_headers()
    encoded_name = package_name.replace("/", "%2F")

    for owner_kind in ("users", "orgs"):
        delete_url = (
            f"https://api.github.com/{owner_kind}/{owner}/packages/container"
            f"/{encoded_name}/versions/{version_id}"
        )
        try:
            response = requests.delete(delete_url, headers=request_headers, timeout=30)
        except requests.RequestException as e:
            raise RegistryError(f"GitHub API request failed: {e}") from e

        status = response.status_code
        if status in (204, 200):
            return
        if status == 404 and owner_kind == "users":
            # Not a user package; retry on the org endpoint.
            continue
        raise RegistryError(
            f"Failed to delete package version {version_id}: "
            f"HTTP {response.status_code} — {response.text[:200]}"
        )
[docs]
def delete_old_versions(
    self,
    below_version: str,
    conference: Optional[str] = None,
    dry_run: bool = False,
    progress_callback: Optional[Callable[[str], None]] = None,
) -> List[Dict[str, Any]]:
    """
    Delete registry package versions whose tag version is older than *below_version*.

    Uses the `GitHub Packages API
    <https://docs.github.com/en/rest/packages/packages>`_ to list and
    delete container image versions. A version is selected when at least
    one of its OCI tags yields a parseable version
    (``{conference}[-{year}]_{model}_{version}``) strictly older than the
    threshold and passes the optional conference filter. Untagged
    (dangling) versions are left untouched.

    Note that deleting a package version removes it together with *all* of
    its tags.

    Parameters
    ----------
    below_version : str
        Threshold version string (PEP 440). Versions **strictly older**
        than this value are deleted. Example: ``"0.4.0"`` deletes all
        versions tagged with a version < 0.4.0.
    conference : str, optional
        When provided, only tags whose base component (the part before the
        first ``_``) equals the sanitized conference name, or starts with it
        followed by ``-``, are examined. The comparison is case-insensitive.
        Tags for other conferences are ignored.
    dry_run : bool, optional
        When ``True``, log which versions *would* be deleted but perform
        no actual deletions (default: ``False``).
    progress_callback : callable, optional
        Function called with status messages during the operation.

    Returns
    -------
    list of dict
        One entry per deleted (or, in dry-run mode, would-be-deleted)
        version. Each dict contains ``version_id``, ``tags`` (all tags on
        that version), and ``version`` (the version parsed from the first
        matching tag).

    Raises
    ------
    RegistryError
        If the registry is not hosted on ``ghcr.io`` or a GitHub API call
        fails.
    ValueError
        If *below_version* is not a valid PEP 440 version string.
    """
    try:
        threshold = Version(below_version)
    except InvalidVersion as e:
        raise ValueError(f"Invalid version string '{below_version}': {e}") from e

    def _progress(msg: str) -> None:
        if progress_callback:
            progress_callback(msg)
        logger.info(msg)

    _progress(f"Fetching package versions from GitHub Packages API ({self.repository}) …")
    all_versions = self._list_github_package_versions()
    _progress(f"Found {len(all_versions)} package version(s) in registry.")

    # _list_github_package_versions has already validated that self.name
    # is "{owner}/{package_name}".
    owner, package_name = self.name.split("/", 1)

    # Loop-invariant: the lowercased conference prefix to match tag bases against.
    conf_prefix = _sanitize_str_for_oci_tag(conference).lower() if conference is not None else None

    deleted: List[Dict[str, Any]] = []
    for pkg_version in all_versions:
        version_id: int = pkg_version["id"]
        # Tolerate explicit nulls anywhere along metadata.container.tags.
        container = (pkg_version.get("metadata") or {}).get("container") or {}
        tags: List[str] = container.get("tags") or []
        if not tags:
            # Skip untagged (dangling) versions
            continue

        # A package version may carry multiple tags; select it when *any*
        # of them belongs to the requested conference and is old enough.
        matched_tags: List[Tuple[str, Version]] = []
        for tag in tags:
            if conf_prefix is not None:
                base = tag.split("_", 1)[0].lower()
                # Accept both exact match (e.g. "neurips") and year-specific
                # (e.g. "neurips-2024").
                if not (base == conf_prefix or base.startswith(conf_prefix + "-")):
                    continue
            tag_version = _parse_version_from_tag(tag)
            if tag_version is not None and tag_version < threshold:
                matched_tags.append((tag, tag_version))

        if not matched_tags:
            continue

        tag_summary = ", ".join(f"{t} (v{v})" for t, v in matched_tags)
        if dry_run:
            _progress(f"  [dry-run] Would delete version {version_id}: {tag_summary}")
        else:
            _progress(f"  Deleting version {version_id}: {tag_summary} …")
            self._delete_github_package_version(owner, package_name, version_id)
            _progress(f"  ✓ Deleted version {version_id}.")

        deleted.append(
            {
                "version_id": version_id,
                "tags": tags,
                "version": str(matched_tags[0][1]),
            }
        )

    action = "would be deleted" if dry_run else "deleted"
    _progress(f"\nDone. {len(deleted)} version(s) {action}.")
    return deleted