Source code for model_registry.core

"""Client for the model registry."""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import TypeVar, cast, overload

from mr_openapi import (
    ApiClient,
    Configuration,
    ModelRegistryServiceApi,
)
from mr_openapi import (
    exceptions as mr_exceptions,
)

from ._utils import required_args
from .types import (
    Artifact,
    Experiment,
    ExperimentRun,
    ExperimentRunArtifact,
    ListOptions,
    ModelArtifact,
    ModelVersion,
    RegisteredModel,
)

# Type variable bound to ``Artifact``: lets a method be annotated as returning
# the same concrete artifact type that it was given.
ArtifactT = TypeVar("ArtifactT", bound=Artifact)


@dataclass
class ModelRegistryAPIClient:
    """Asynchronous client for the model registry API.

    Requests are issued through a short-lived ``ModelRegistryServiceApi``
    obtained from :meth:`get_client`, which is built from ``config`` and closed
    again when the request is done.

    List methods accept an optional ``ListOptions``; when one is passed, its
    ``next_page_token`` is overwritten with the token of the page just fetched.
    """

    # OpenAPI client configuration used to build the API client for each request.
    config: Configuration

    @classmethod
    def secure_connection(
        cls,
        server_address: str,
        port: int = 443,
        *,
        user_token: str,
        custom_ca: str | None = None,
    ) -> ModelRegistryAPIClient:
        """Build a client that keeps TLS verification enabled.

        Args:
            server_address: Server address.
            port: Server port. Defaults to 443.

        Keyword Args:
            user_token: User token, passed to the configuration as the access token.
            custom_ca: Passed to the configuration as ``ssl_ca_cert``; presumably
                the path to a PEM-encoded CA certificate bundle.

        Returns:
            A client configured for ``{server_address}:{port}``.
        """
        return cls(
            Configuration(
                f"{server_address}:{port}",
                access_token=user_token,
                ssl_ca_cert=custom_ca,
            )
        )

    @classmethod
    def insecure_connection(
        cls,
        server_address: str,
        port: int,
        user_token: str | None = None,
    ) -> ModelRegistryAPIClient:
        """Build a client with TLS certificate verification disabled.

        Args:
            server_address: Server address.
            port: Server port.
            user_token: Optional user token, passed to the configuration as the
                access token.

        Returns:
            A client configured for ``{server_address}:{port}``.
        """
        config = Configuration(
            host=f"{server_address}:{port}",
            access_token=user_token,
            verify_ssl=False,
        )
        return cls(config)

    @staticmethod
    def _list_kwargs(options: ListOptions | None):
        """Return the keyword arguments for a list request, using defaults when no options are given."""
        return (options or ListOptions()).as_options()

    @staticmethod
    def _store_next_page(options: ListOptions | None, page) -> None:
        """Record the page's continuation token on the caller's options, if any were passed."""
        if options:
            options.next_page_token = page.next_page_token

    @asynccontextmanager
    async def get_client(self) -> AsyncIterator[ModelRegistryServiceApi]:
        """Yield a service API bound to a fresh API client.

        The underlying ``ApiClient`` is closed when the context exits, whether
        normally or through an exception.
        """
        api_client = ApiClient(self.config)
        try:
            # Built inside the ``try`` so the API client is closed even if
            # constructing the service wrapper fails.
            yield ModelRegistryServiceApi(api_client)
        finally:
            await api_client.close()

    async def upsert_registered_model(self, registered_model: RegisteredModel) -> RegisteredModel:
        """Upsert a registered model.

        Updates the model when it already has an ID, otherwise creates it.

        Args:
            registered_model: Registered model.

        Returns:
            The registered model as returned by the server.
        """
        async with self.get_client() as client:
            if registered_model.id:
                rm = await client.update_registered_model(registered_model.id, registered_model.update())
            else:
                rm = await client.create_registered_model(registered_model.create())

        return RegisteredModel.from_basemodel(rm)

    async def get_registered_model_by_id(self, id: str) -> RegisteredModel | None:
        """Fetch a registered model by its ID.

        Args:
            id: Registered model ID.

        Returns:
            Registered model, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                rm = await client.get_registered_model(id)
            except mr_exceptions.NotFoundException:
                return None

        return RegisteredModel.from_basemodel(rm)

    @overload
    async def get_registered_model_by_params(self, name: str) -> RegisteredModel | None: ...

    @overload
    async def get_registered_model_by_params(self, *, external_id: str) -> RegisteredModel | None: ...

    @required_args(("name",), ("external_id",))
    async def get_registered_model_by_params(
        self, name: str | None = None, external_id: str | None = None
    ) -> RegisteredModel | None:
        """Fetch a registered model by its name or external ID.

        Args:
            name: Registered model name.
            external_id: Registered model external ID.

        Returns:
            Registered model, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                rm = await client.find_registered_model(name=name, external_id=external_id)
            except mr_exceptions.NotFoundException:
                return None

        return RegisteredModel.from_basemodel(rm)

    async def get_registered_models(self, options: ListOptions | None = None) -> list[RegisteredModel]:
        """Fetch registered models.

        Args:
            options: Options for listing registered models.

        Returns:
            Registered models.
        """
        async with self.get_client() as client:
            rm_list = await client.get_registered_models(**self._list_kwargs(options))

        self._store_next_page(options, rm_list)
        return [RegisteredModel.from_basemodel(rm) for rm in rm_list.items or []]

    async def upsert_model_version(
        self, model_version: ModelVersion, registered_model_id: str | None = None
    ) -> ModelVersion:
        """Upsert a model version.

        Updates the version when it already has an ID, otherwise creates it
        under the given registered model.

        Args:
            model_version: Model version to upsert.
            registered_model_id: ID of the registered model this version will be
                associated to. Can be None when updating an existing model version.

        Returns:
            The model version as returned by the server.

        Raises:
            ValueError: If the version has no ID and no registered model ID is given.
        """
        async with self.get_client() as client:
            if model_version.id:
                mv = await client.update_model_version(model_version.id, model_version.update())
            elif registered_model_id:
                mv = await client.create_model_version(model_version.create(registered_model_id=registered_model_id))
            else:
                msg = f"Registered model ID required for creating a new model version: {model_version}"
                raise ValueError(msg)

        return ModelVersion.from_basemodel(mv)

    async def get_model_version_by_id(self, model_version_id: str) -> ModelVersion | None:
        """Fetch a model version by its ID.

        Args:
            model_version_id: Model version ID.

        Returns:
            Model version, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                mv = await client.get_model_version(model_version_id)
            except mr_exceptions.NotFoundException:
                return None

        return ModelVersion.from_basemodel(mv)

    async def get_model_versions(
        self, registered_model_id: str, options: ListOptions | None = None
    ) -> list[ModelVersion]:
        """Fetch model versions by registered model ID.

        Args:
            registered_model_id: Registered model ID.
            options: Options for listing model versions.

        Returns:
            Model versions.
        """
        async with self.get_client() as client:
            mv_list = await client.get_registered_model_versions(
                registered_model_id, **self._list_kwargs(options)
            )

        self._store_next_page(options, mv_list)
        return [ModelVersion.from_basemodel(mv) for mv in mv_list.items or []]

    @overload
    async def get_model_version_by_params(self, registered_model_id: str, name: str) -> ModelVersion | None: ...

    @overload
    async def get_model_version_by_params(self, *, external_id: str) -> ModelVersion | None: ...

    @required_args(
        (
            "registered_model_id",
            "name",
        ),
        ("external_id",),
    )
    async def get_model_version_by_params(
        self,
        registered_model_id: str | None = None,
        name: str | None = None,
        external_id: str | None = None,
    ) -> ModelVersion | None:
        """Fetch a model version by associated parameters.

        Either fetches by using external ID or by using registered model ID and
        version name.

        Args:
            registered_model_id: Registered model ID.
            name: Model version name.
            external_id: Model version external ID.

        Returns:
            Model version, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                mv = await client.find_model_version(
                    name=name,
                    external_id=external_id,
                    parent_resource_id=registered_model_id,
                )
            except mr_exceptions.NotFoundException:
                return None

        return ModelVersion.from_basemodel(mv)

    async def upsert_model_artifact(self, model_artifact: ModelArtifact) -> ModelArtifact:
        """Upsert a model artifact.

        Creates the artifact when it has no ID yet, otherwise updates it.

        Args:
            model_artifact: Model artifact to upsert.

        Returns:
            The model artifact as returned by the server.
        """
        async with self.get_client() as client:
            if model_artifact.id:
                ma = await client.update_model_artifact(model_artifact.id, model_artifact.update())
            else:
                ma = await client.create_model_artifact(model_artifact.create())

        return ModelArtifact.from_basemodel(ma)

    async def upsert_model_version_artifact(self, artifact: ArtifactT, model_version_id: str) -> ArtifactT:
        """Upsert an artifact under a model version.

        Args:
            artifact: Model version artifact to upsert.
            model_version_id: ID of the model version this artifact will be associated to.

        Returns:
            The artifact as returned by the server.
        """
        async with self.get_client() as client:
            response = await client.upsert_model_version_artifact(model_version_id, artifact.wrap())

        # ``cast`` only informs the type checker; no runtime check is made that
        # the validated artifact has the same concrete type as the input.
        return cast(ArtifactT, Artifact.validate_artifact(response))

    async def get_model_artifact_by_id(self, id: str) -> ModelArtifact | None:
        """Fetch a model artifact by its ID.

        Args:
            id: Model artifact ID.

        Returns:
            Model artifact, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                ma = await client.get_model_artifact(id)
            except mr_exceptions.NotFoundException:
                return None

        return ModelArtifact.from_basemodel(ma)

    @overload
    async def get_model_artifact_by_params(
        self,
        name: str,
        model_version_id: str,
    ) -> ModelArtifact | None: ...

    @overload
    async def get_model_artifact_by_params(self, *, external_id: str) -> ModelArtifact | None: ...

    @required_args(
        (
            "name",
            "model_version_id",
        ),
        ("external_id",),
    )
    async def get_model_artifact_by_params(
        self,
        name: str | None = None,
        model_version_id: str | None = None,
        external_id: str | None = None,
    ) -> ModelArtifact | None:
        """Fetch a model artifact by external ID, or by name and model version ID.

        Args:
            name: Model artifact name.
            model_version_id: ID of the associated model version.
            external_id: Model artifact external ID.

        Returns:
            Model artifact, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                ma = await client.find_model_artifact(
                    name=name,
                    parent_resource_id=model_version_id,
                    external_id=external_id,
                )
            except mr_exceptions.NotFoundException:
                return None

        return ModelArtifact.from_basemodel(ma)

    async def get_model_artifacts(
        self,
        model_version_id: str | None = None,
        options: ListOptions | None = None,
    ) -> list[ModelArtifact]:
        """Fetch model artifacts.

        With a model version ID, the artifacts of that version are listed and
        only those that validate as ``ModelArtifact`` are kept. Without one,
        model artifacts are listed directly.

        Args:
            model_version_id: ID of the associated model version.
            options: Options for listing model artifacts.

        Returns:
            Model artifacts.
        """
        async with self.get_client() as client:
            if model_version_id:
                art_list = await client.get_model_version_artifacts(
                    model_version_id, **self._list_kwargs(options)
                )
                self._store_next_page(options, art_list)
                validated = (Artifact.validate_artifact(art) for art in art_list.items or [])
                # A model version can hold other artifact kinds; keep model artifacts only.
                return [art for art in validated if isinstance(art, ModelArtifact)]

            ma_list = await client.get_model_artifacts(**self._list_kwargs(options))

        self._store_next_page(options, ma_list)
        return [ModelArtifact.from_basemodel(ma) for ma in ma_list.items or []]

    async def get_model_version_artifacts(
        self,
        model_version_id: str,
        options: ListOptions | None = None,
    ) -> list[Artifact]:
        """Fetch all artifacts of a model version, whatever their kind.

        Args:
            model_version_id: ID of the associated model version.
            options: Options for listing artifacts.

        Returns:
            Artifacts of the model version.
        """
        async with self.get_client() as client:
            art_list = await client.get_model_version_artifacts(
                model_version_id, **self._list_kwargs(options)
            )

        self._store_next_page(options, art_list)
        return [Artifact.validate_artifact(art) for art in art_list.items or []]

    async def upsert_experiment(self, experiment: Experiment) -> Experiment:
        """Upsert an experiment.

        Updates the experiment when it has an ID. Otherwise it is looked up by
        name: an existing experiment with that name is updated, else a new one
        is created.

        Args:
            experiment: Experiment to upsert.

        Returns:
            The experiment as returned by the server.

        Raises:
            ValueError: If the experiment has neither an ID nor a name.
        """
        if not (experiment.id or experiment.name):
            # Without this guard the result variable would never be assigned.
            msg = f"Experiment ID or name required for upserting an experiment: {experiment}"
            raise ValueError(msg)

        async with self.get_client() as client:
            if experiment.id:
                exp = await client.update_experiment(experiment.id, experiment.update())
            else:
                existing = await self.get_experiment_by_name(experiment.name)
                if existing:
                    exp = await client.update_experiment(existing.id, experiment.update())
                else:
                    exp = await client.create_experiment(experiment.create())

        return Experiment.from_basemodel(exp)

    async def get_experiment_by_name(self, name: str) -> Experiment | None:
        """Fetch an experiment by its name.

        Args:
            name: Experiment name.

        Returns:
            Experiment, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                exp = await client.find_experiment(name=name)
            except mr_exceptions.NotFoundException:
                return None

        return Experiment.from_basemodel(exp)

    async def get_experiment_by_id(self, id: str | int) -> Experiment | None:
        """Fetch an experiment by its ID.

        Args:
            id: Experiment ID; converted to a string for the request.

        Returns:
            Experiment, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                exp = await client.get_experiment(str(id))
            except mr_exceptions.NotFoundException:
                return None

        return Experiment.from_basemodel(exp)

    async def get_experiments(self, options: ListOptions | None = None) -> list[Experiment]:
        """Fetch experiments.

        Args:
            options: Options for listing experiments.

        Returns:
            Experiments.
        """
        async with self.get_client() as client:
            exp_list = await client.get_experiments(**self._list_kwargs(options))

        self._store_next_page(options, exp_list)
        return [Experiment.from_basemodel(exp) for exp in exp_list.items or []]

    async def upsert_experiment_run(self, experiment_run: ExperimentRun) -> ExperimentRun:
        """Upsert an experiment run.

        Updates the run when it already has an ID, otherwise creates it.

        Args:
            experiment_run: Experiment run to upsert.

        Returns:
            The experiment run as returned by the server.
        """
        async with self.get_client() as client:
            if experiment_run.id:
                exp_run = await client.update_experiment_run(experiment_run.id, experiment_run.update())
            else:
                exp_run = await client.create_experiment_run(experiment_run.create())

        return ExperimentRun.from_basemodel(exp_run)

    async def get_experiment_runs_by_experiment_id(
        self, experiment_id: str | int, options: ListOptions | None = None
    ) -> list[ExperimentRun]:
        """Fetch experiment runs by experiment ID.

        Args:
            experiment_id: Experiment ID; converted to a string for the request.
            options: Options for listing experiment runs.

        Returns:
            Experiment runs; empty if the server reports the experiment as not found.
        """
        async with self.get_client() as client:
            try:
                exp_runs = await client.get_experiment_experiment_runs(
                    str(experiment_id), **self._list_kwargs(options)
                )
            except mr_exceptions.NotFoundException:
                return []

        self._store_next_page(options, exp_runs)
        return [ExperimentRun.from_basemodel(exp_run) for exp_run in exp_runs.items or []]

    async def get_experiment_runs_by_experiment_name(
        self, experiment_name: str, options: ListOptions | None = None
    ) -> list[ExperimentRun]:
        """Fetch experiment runs by experiment name.

        Args:
            experiment_name: Experiment name.
            options: Options for listing experiment runs.

        Returns:
            Experiment runs; empty if no experiment with that name is found.
        """
        exp = await self.get_experiment_by_name(experiment_name)
        if not exp:
            return []
        return await self.get_experiment_runs_by_experiment_id(exp.id, options)

    async def get_experiment_run_by_experiment_and_run_id(
        self,
        run_id: str | int,
        experiment_name: str | None = None,
        experiment_id: str | int | None = None,
    ) -> ExperimentRun | None:
        """Fetch an experiment run by experiment name / ID and the run ID.

        The experiment is only checked for existence; the run itself is then
        fetched by ``run_id`` alone.

        Args:
            run_id: Run ID.
            experiment_name: Experiment name. Takes precedence over ``experiment_id``.
            experiment_id: Experiment ID.

        Returns:
            Experiment run, or None if the experiment or the run is not found.

        Raises:
            ValueError: If neither ``experiment_name`` nor ``experiment_id`` is given.
        """
        if experiment_name:
            exp = await self.get_experiment_by_name(experiment_name)
        elif experiment_id:
            exp = await self.get_experiment_by_id(str(experiment_id))
        else:
            msg = "Either experiment_name or experiment_id must be provided"
            raise ValueError(msg)
        if not exp:
            return None

        async with self.get_client() as client:
            try:
                exp_run = await client.get_experiment_run(str(run_id))
            except mr_exceptions.NotFoundException:
                return None

        return ExperimentRun.from_basemodel(exp_run)

    async def get_experiment_run_by_experiment_and_run_name(
        self,
        run_name: str,
        experiment_name: str | None = None,
        experiment_id: str | int | None = None,
        options: ListOptions | None = None,
    ) -> ExperimentRun | None:
        """Fetch an experiment run by experiment name / ID and the run name.

        The experiment's runs are listed and the first one whose name equals
        ``run_name`` is returned. Only a single page of runs is searched: the
        page described by ``options``, or the first 100 runs when ``options``
        is not given.

        Args:
            run_name: Run name.
            experiment_name: Experiment name. Takes precedence over ``experiment_id``.
            experiment_id: Experiment ID.
            options: Options for listing the experiment's runs.

        Returns:
            Experiment run, or None if the experiment is not given, not found,
            or has no run with that name on the searched page.
        """
        exp = None
        if experiment_name:
            exp = await self.get_experiment_by_name(experiment_name)
        elif experiment_id:
            exp = await self.get_experiment_by_id(str(experiment_id))
        if not exp:
            return None

        runs = await self.get_experiment_runs_by_experiment_id(exp.id, options or ListOptions(limit=100))
        return next((run for run in runs if run.name == run_name), None)

    async def get_experiment_run_by_id(self, id: str) -> ExperimentRun | None:
        """Fetch an experiment run by its ID.

        Args:
            id: Experiment run ID.

        Returns:
            Experiment run, or None if the server reports it as not found.
        """
        async with self.get_client() as client:
            try:
                exp_run = await client.get_experiment_run(id)
            except mr_exceptions.NotFoundException:
                return None

        return ExperimentRun.from_basemodel(exp_run)

    async def upsert_experiment_run_artifact(
        self, experiment_run_id: str, artifact: ExperimentRunArtifact
    ) -> ExperimentRunArtifact:
        """Upsert an experiment run artifact (parameter, metric, or dataset).

        Args:
            experiment_run_id: Experiment run ID.
            artifact: Artifact to upsert.

        Returns:
            The artifact as returned by the server.
        """
        async with self.get_client() as client:
            response = await client.upsert_experiment_run_artifact(
                experimentrun_id=experiment_run_id, artifact=artifact.wrap()
            )

        return Artifact.validate_artifact(response)

    @overload
    async def get_artifacts_by_experiment_run_params(
        self, run_id: str | int, *, options: ListOptions | None = None
    ) -> list[ExperimentRunArtifact]: ...

    @overload
    async def get_artifacts_by_experiment_run_params(
        self,
        *,
        run_name: str,
        experiment_name: str,
        options: ListOptions | None = None,
    ) -> list[ExperimentRunArtifact]: ...

    @overload
    async def get_artifacts_by_experiment_run_params(
        self,
        *,
        run_name: str,
        experiment_id: str | int,
        options: ListOptions | None = None,
    ) -> list[ExperimentRunArtifact]: ...

    @required_args(("run_id",), ("run_name", "experiment_name"), ("run_name", "experiment_id"))
    async def get_artifacts_by_experiment_run_params(
        self,
        run_id: str | int | None = None,
        run_name: str | None = None,
        experiment_name: str | None = None,
        experiment_id: str | int | None = None,
        *,
        options: ListOptions | None = None,
    ) -> list[ExperimentRunArtifact]:
        """Fetch the artifacts of an experiment run.

        The run is identified either by ``run_id``, or by ``run_name`` together
        with the experiment's name or ID. A lookup by name only searches the
        first 100 runs of the experiment.

        Args:
            run_id: Experiment run ID.
            run_name: Experiment run name.
            experiment_name: Experiment name.
            experiment_id: Experiment ID.

        Keyword Args:
            options: Options for listing experiment run artifacts.

        Returns:
            Artifacts of the run; empty if the run is not found.

        Raises:
            ValueError: If ``run_name`` is used without an experiment name or ID.
        """
        if not run_id and run_name:
            # Resolve the run name to an ID by scanning the experiment's runs.
            first_runs = ListOptions(limit=100)
            if experiment_name:
                exp_runs = await self.get_experiment_runs_by_experiment_name(
                    experiment_name=experiment_name,
                    options=first_runs,
                )
            elif experiment_id:
                exp_runs = await self.get_experiment_runs_by_experiment_id(
                    experiment_id=experiment_id,
                    options=first_runs,
                )
            else:
                msg = "Either experiment_name or experiment_id must be provided"
                raise ValueError(msg)

            run = next((r for r in exp_runs if r.name == run_name), None)
            if not run:
                # Report whichever experiment identifier the caller supplied.
                print(
                    f"Could not find run {run_name} "
                    f"in experiment {experiment_name or experiment_id} within the first 100 runs. "
                    "Please narrow your search by run id."
                )
                return []
            run_id = run.id

        async with self.get_client() as client:
            try:
                logs = await client.get_experiment_run_artifacts(str(run_id), **self._list_kwargs(options))
            except mr_exceptions.NotFoundException:
                return []

        self._store_next_page(options, logs)
        return [Artifact.validate_artifact(log) for log in logs.items or []]

    async def get_artifacts(
        self,
        filter_query: str | None = None,
        artifact_type: str | None = None,
        options: ListOptions | None = None,
    ) -> list[Artifact]:
        """Get artifacts with filtering capabilities.

        Args:
            filter_query: A SQL-like query string to filter the list of entities.
            artifact_type: Specifies the artifact type for listing artifacts.
            options: Options for listing artifacts.

        Returns:
            List of artifacts matching the filter criteria; empty if the server
            responds with not found.
        """
        # Start from the paging options and add only the filters that were given.
        request_kwargs = self._list_kwargs(options)
        if filter_query is not None:
            request_kwargs["filter_query"] = filter_query
        if artifact_type is not None:
            request_kwargs["artifact_type"] = artifact_type

        async with self.get_client() as client:
            try:
                response = await client.get_artifacts(**request_kwargs)
            except mr_exceptions.NotFoundException:
                return []

        self._store_next_page(options, response)
        return [Artifact.validate_artifact(artifact) for artifact in response.items or []]