Source code for nannyml.io.db.database_writer

from typing import Any, Dict, Optional

from sqlmodel import Session, SQLModel, create_engine, select

from nannyml._typing import Result
from nannyml.exceptions import WriterException
from nannyml.io.base import Writer, WriterFactory
from nannyml.io.db.entities import Model, Run
from nannyml.io.db.mappers import MapperFactory
from nannyml.usage_logging import UsageEvent, log_usage


@WriterFactory.register('database')  # registration name matches property used in configuration file
class DatabaseWriter(Writer):
    """A :class:`~nannyml.io.base.Writer` implementation that writes a ``Result`` into a database table.

    The ``Result`` class is transformed into a list of *DbMetric* objects by an appropriate
    :class:`~nannyml.io.db.mappers.Mapper` instance. These *DbMetrics* are written into a database table,
    specific to the ``Result`` class.

    This supports any database that is compatible with *SQLAlchemy*.
    """

    def __init__(
        self,
        connection_string: str,
        connection_options: Optional[Dict[str, Any]] = None,
        model_name: Optional[str] = None,
    ):
        """
        Creates a new ``DatabaseWriter`` instance.

        Parameters
        ----------
        connection_string: str
            The connection string that configures the connection to the database.
            Might contain user credentials as well.
        connection_options: Dict[str, Any], default=None
            Additional options passed along to the underlying *SQLAlchemy* engine.
        model_name: str, default=None
            An optional name for the model being monitored. When given this will cause a record to be created
            in the ``models`` table and having each *DbMetric* link to that one.
            This allows easy filtering and dropdown population in data visualization tools in case of multiple
            models exporting into the same database structure.

        Raises
        ------
        WriterException
            When creating the tables, upserting the model or creating the run record fails.
            The original exception is attached as ``__cause__``. Errors raised while constructing
            the engine itself are not wrapped.

        Examples
        --------
        >>> # write to local in-memory database
        >>> sqlite_writer = DatabaseWriter(connection_string='sqlite:///', model_name='car_loan_prediction')
        >>> sqlite_writer.write(result)

        >>> postgres_writer = DatabaseWriter(
        ...     connection_string='postgresql://postgres:mysecretpassword@localhost:5432/postgres',
        ...     model_name='car_loan_prediction'
        ... )
        >>> postgres_writer.write(result)
        """
        super().__init__()
        self.connection_string = connection_string
        if connection_options is None:
            connection_options = {}
        self._engine = create_engine(url=connection_string, **connection_options)

        try:
            # Make sure every table known to SQLModel's metadata exists before inserting anything.
            SQLModel.metadata.create_all(self._engine)

            # find or create a 'model' and store the id
            self.model_id = self._upsert_model(model_name)

            # create the "run" and store the id
            self.run_id = self._create_run(model_id=self.model_id)
        except Exception as exc:
            # Chain explicitly so the underlying database error stays available as ``__cause__``.
            raise WriterException(f"could not create DatabaseWriter: {exc}") from exc

    @log_usage(UsageEvent.WRITE_DB)
    def _write(self, result: Result, **kwargs):
        """Map ``result`` to database entities and insert them in a single transaction.

        Parameters
        ----------
        result: Result
            The result to persist. A mapper is selected for it through ``MapperFactory.create``.
        **kwargs
            Accepted for interface compatibility; not used by this implementation.
        """
        mapper = MapperFactory.create(result)

        with Session(self._engine) as session:
            # Every mapped entity is linked to the run (and optional model) created in ``__init__``.
            metrics = mapper.map_to_entity(result, run_id=self.run_id, model_id=self.model_id)
            session.add_all(metrics)
            session.commit()

    def _create_run(self, **run_args) -> int:
        """Inserts a new record into the 'run' table and returns the id.

        Parameters
        ----------
        **run_args
            Keyword arguments forwarded verbatim to the ``Run`` entity constructor.

        Returns
        -------
        int
            The identifier of the newly inserted run.

        Raises
        ------
        RuntimeError
            When the run has no identifier after being committed and refreshed.
        """
        run = Run(**run_args)
        with Session(self._engine) as session:
            session.add(run)
            session.commit()
            # Reload the instance so the identifier is populated before the session closes.
            session.refresh(run)

        if run.id is None:
            raise RuntimeError("could not retrieve run identifier from the database")
        return run.id

    def _upsert_model(self, model_name: Optional[str] = None) -> Optional[int]:
        """Upsert a model given a model name, returns the model id.

        Parameters
        ----------
        model_name: str, default=None
            Name of the model to look up or create.

        Returns
        -------
        Optional[int]
            ``None`` when no model name was given, otherwise the id of the existing or
            newly created ``Model`` record.
        """
        # No model specified
        if model_name is None:
            return None

        with Session(self._engine) as session:
            model = session.exec(select(Model).where(Model.name == model_name)).first()

            if model is None:
                # NOTE(review): ``self._logger`` is not defined in this class - presumably
                # provided by the ``Writer`` base class; confirm.
                self._logger.info(f"could not find a model with name '{model_name}', creating new")
                model = Model(name=model_name)
                session.add(model)
                session.commit()

            # Read the id while the session is still open so the attribute can be loaded if needed.
            return model.id