# felicity-lims/felicity/database/base_class.py
# 2023-04-07 17:52:19 +02:00
#
# 569 lines
# 18 KiB
# Python

import logging
from base64 import b64decode, b64encode
from typing import Any, AsyncIterator, Dict, List, Optional, TypeVar, Union
from pydantic import BaseModel as PydanticBaseModel
from sqlalchemy import Column
from sqlalchemy import or_ as sa_or_
from sqlalchemy import update
from sqlalchemy.future import select
from sqlalchemy.orm import as_declarative, declared_attr, selectinload
from sqlalchemy.sql import func
from felicity.core.uid_gen import FelicitySAID, get_flake_uid
from felicity.database.async_mixins import (
AllFeaturesMixin,
ModelNotFoundError,
smart_query,
)
from felicity.database.paginator.cursor import EdgeNode, PageCursor, PageInfo
from felicity.database.session import async_session_factory
from felicity.utils import has_value_or_is_truthy
# Generic Pydantic schema type accepted by ``DBModel._import``.
InDBSchemaType = TypeVar("InDBSchemaType", bound=PydanticBaseModel)
# NOTE(review): configuring the root logger when this library module is
# imported affects the whole process; consider moving it to the app entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Developer notes kept as a bare (unused) string literal: links on SQLAlchemy
# session events / global WHERE criteria that may help with department filters
# and created_by/updated_by auditing.
"""?? Usefull Tools
1. https://docs.sqlalchemy.org/en/14/orm/session_events.html#adding-global-where-on-criteria
Maybe for global department filters ?
2. https://stackoverflow.com/questions/12753450/sqlalchemy-mixins-and-event-listener
SQLAlchemy "event.listen" for all models
event.listen(MyBaseMixin, 'before_insert', get_created_by_id, propagate=True)
event.listen(MyBaseMixin, 'before_update', get_updated_by_id, propagate=True)
"""
@as_declarative()
class DBModel(AllFeaturesMixin):
    """Abstract declarative base for Felicity models.

    Provides a ``uid`` primary key generated by ``get_flake_uid``, an automatic
    table name, simple serialisation helpers and async CRUD/query helpers.
    Every helper opens its own short-lived session from
    ``async_session_factory``.
    """

    __name__: str
    __abstract__ = True
    __mapper_args__ = {"eager_defaults": True}

    uid = Column(
        FelicitySAID,
        primary_key=True,
        index=True,
        nullable=False,
        default=get_flake_uid,
    )

    @declared_attr
    def __tablename__(cls) -> str:
        """Generate the table name automatically from the lower-cased class name."""
        return cls.__name__.lower()

    def marshal_simple(self, exclude=None):
        """Convert the instance to a shallow dict built from ``self.__dict__``.

        :param exclude: optional iterable of attribute names to leave out.
            SQLAlchemy's ``_sa_instance_state`` is always excluded. The
            caller's collection is never modified.
        :return: dict mapping attribute name to its (unconverted) value.
        """
        # Build a new set instead of appending to ``exclude``: appending would
        # mutate a list passed in by the caller on every call.
        skipped = set(exclude or ())
        skipped.add("_sa_instance_state")
        return {
            field: value
            for field, value in self.__dict__.items()
            if field not in skipped
        }

    def marshal_nested(self, obj=None):
        """Recursively convert ``obj`` (default: this instance) to plain data.

        Dicts are converted value by value, objects exposing ``_ast()`` are
        converted via its result, non-string iterables become lists, and other
        objects with a ``__dict__`` become dicts of their public, non-callable
        attributes. Anything else is returned unchanged.
        """
        return self._marshal_value(self if obj is None else obj)

    @classmethod
    def _marshal_value(cls, obj):
        """Recursive worker for ``marshal_nested``.

        Kept separate so that a nested ``None`` stays ``None``: routing it back
        through ``marshal_nested(None)`` would substitute ``self`` and recurse
        forever.
        """
        if isinstance(obj, dict):
            return {k: cls._marshal_value(v) for k, v in obj.items()}
        if hasattr(obj, "_ast"):
            return cls._marshal_value(obj._ast())
        # Text/binary values are iterable but must be kept as scalars.
        if not isinstance(obj, (str, bytes, bytearray)) and hasattr(obj, "__iter__"):
            return [cls._marshal_value(v) for v in obj]
        if hasattr(obj, "__dict__"):
            return {
                k: cls._marshal_value(v)
                for k, v in obj.__dict__.items()
                if not callable(v) and not k.startswith("_")
            }
        return obj

    @classmethod
    async def all_by_page(cls, page: int = 1, limit: int = 20, **kwargs) -> List:
        """Return one page of records matching ``kwargs`` (offset pagination).

        :param page: 1-based page number; values below 1 are treated as 1.
        :param limit: maximum number of records per page.
        :param kwargs: filters forwarded to ``cls.where``.
        :return: list of model instances.
        """
        # Clamp so page 0 / negative pages never produce a negative OFFSET.
        start = max(page - 1, 0) * limit
        stmt = cls.where(**kwargs).limit(limit).offset(start)
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            return results.scalars().all()

    async def delete(self):
        """Delete this instance in a fresh session and commit immediately."""
        async with async_session_factory() as session:
            await session.delete(self)
            await session.flush()
            await session.commit()

    @classmethod
    async def destroy(cls, *ids):
        """Delete the records with the given primary keys.

        Each record is looked up with ``find`` and deleted (and committed)
        individually; ids that match no record are ignored.

        :param ids: primary key values of the records to delete.
        """
        for pk in ids:
            obj = await cls.find(pk)
            if obj:
                await obj.delete()

    @classmethod
    async def all(cls):
        """Return every record of this model."""
        async with async_session_factory() as session:
            result = await session.execute(select(cls))
            return result.scalars().all()

    @classmethod
    async def first(cls):
        """Return one record of this model (no ordering is applied), or None."""
        async with async_session_factory() as session:
            # LIMIT 1 avoids buffering the whole table just to keep one row.
            result = await session.execute(select(cls).limit(1))
            return result.scalars().first()

    @classmethod
    async def find(cls, id_):
        """Find a record by primary key.

        :param id_: the ``uid`` value to look up.
        :return: the matching instance or None.
        """
        stmt = cls.where(uid=id_)
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            return results.scalars().one_or_none()

    @classmethod
    async def find_or_fail(cls, id_):
        """Like ``find`` but raise ``ModelNotFoundError`` when nothing matches."""
        result = await cls.find(id_)
        if result:
            return result
        raise ModelNotFoundError(f"{cls.__name__} with uid '{id_}' was not found")

    @classmethod
    async def get(cls, **kwargs):
        """Return the first record matching the given filters, or None.

        Example:
            await User.get(uid=5)
        """
        stmt = cls.where(**kwargs)
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            return results.scalars().first()

    @staticmethod
    async def db_commit():
        """Commit a freshly opened session.

        NOTE(review): the session is brand new and empty, so there is nothing
        pending to commit; this cannot commit work done by other helpers.
        """
        async with async_session_factory() as session:
            await session.commit()

    @staticmethod
    async def db_flush():
        """Flush a freshly opened session.

        NOTE(review): the session is brand new and empty, so this is
        effectively a no-op.
        """
        async with async_session_factory() as session:
            await session.flush()

    @classmethod
    async def create(cls, **kwargs):
        """Create and persist a new instance, then return a freshly fetched copy.

        The instance is re-read with ``get`` after saving so mutations receive
        a fully loaded object and avoid async IO issues.
        """
        instance = cls().fill(**kwargs)
        created = await instance.save()
        if created:
            created = await cls.get(uid=created.uid)
        return created

    @classmethod
    async def bulk_create(cls, items: List):
        """Create and persist many instances in one transaction.

        :param items: a list of Pydantic models or dicts.
        :return: the saved instances.
        """
        to_save = [cls().fill(**cls._import(data)) for data in items]
        return await cls.save_all(to_save)

    async def update(self, **kwargs):
        """Apply ``kwargs`` to this instance, save it and return a fresh copy.

        The instance is re-read with ``get`` after saving so mutations receive
        a fully loaded object and avoid async IO issues.
        """
        fill = self.fill(**kwargs)
        updated = await fill.save()
        if updated:
            updated = await self.get(uid=updated.uid)
        return updated

    @classmethod
    async def bulk_update_where(cls, update_data: List, filters: Dict):
        """Update the records matching ``filters``.

        :param update_data: a list of dictionary update values,
            e.g. [{'uid': 34, update_values}, ...].
        :param filters: a dict of filter values understood by ``smart_query``.

        NOTE(review): SQLAlchemy appears to accept multiple parameter sets in
        ``values()`` only for INSERT, an UPDATE without RETURNING yields no
        rows for ``scalars()``, and nothing is committed here. This method
        likely cannot succeed as written; confirm with callers.
        """
        to_update = [cls._import(data) for data in update_data]
        query = smart_query(query=update(cls), filters=filters)
        stmt = query.values(to_update).execution_options(synchronize_session="fetch")
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            return results.scalars().all()

    @classmethod
    async def bulk_update_with_mappings(cls, mappings: List) -> None:
        """Update many rows by primary key with one executemany UPDATE.

        :param mappings: a list of dictionaries (or Pydantic models), each
            containing ``uid`` plus the columns to update,
            e.g. [{'uid': 34, update_values}, ...]. Every mapping must have the
            same keys as the first one. The caller's dicts are not modified.

        ?? there must be zero many-to-many relations
        NB: Function does not return anything.
        """
        if not mappings:
            return

        from sqlalchemy.sql.expression import bindparam

        to_update = []
        for data in mappings:
            # Copy: ``_import`` returns dict inputs as-is, and adding ``_uid``
            # in place would leak into the caller's data.
            row = dict(cls._import(data))
            # SQLAlchemy reserves column names as bind names for the SET
            # clause, so the WHERE parameter needs a distinct name.
            row["_uid"] = row["uid"]
            to_update.append(row)

        query = update(cls).where(cls.uid == bindparam("_uid"))
        binds = {key: bindparam(key) for key in to_update[0] if key != "_uid"}
        stmt = query.values(binds).execution_options(synchronize_session="fetch")
        async with async_session_factory() as session:
            await session.execute(stmt, to_update)
            await session.flush()
            await session.commit()

    @classmethod
    async def bulk_update_with_mappings_not_working(cls, mappings: List):
        """Update many rows via ``Session.bulk_update_mappings``.

        ``AsyncSession`` does not expose ``bulk_update_mappings``, so the call
        is run on the underlying synchronous session through ``run_sync``.

        :param mappings: a list of dictionaries (or Pydantic models) with pks,
            e.g. [{'uid': 34, update_values}, ...].
        :return: the list of dicts that was applied.
        """
        to_update = [cls._import(data) for data in mappings]
        async with async_session_factory() as session:
            await session.run_sync(
                lambda sync_session: sync_session.bulk_update_mappings(cls, to_update)
            )
            await session.flush()
            await session.commit()
        return to_update

    @classmethod
    async def table_insert(cls, table, mappings):
        """Execute ``table.insert()`` with ``mappings`` and commit.

        :param table: a Core ``Table`` (e.g. an association table).
        :param mappings: a dictionary of column values,
            e.g. {'name': 34, 'day': "fff"}, or a list of such dicts.
        """
        async with async_session_factory() as session:
            stmt = table.insert()
            await session.execute(stmt, mappings)
            await session.commit()

    @classmethod
    async def get_related(cls, related: Optional[list] = None, list=False, **kwargs):
        """Return record(s) matching ``kwargs`` with relationships eagerly loaded.

        :param related: a loader path accepted by ``selectinload`` (relationship
            attribute or name), or a list/tuple of them.
        :param list: if true return all matches, otherwise only the first one.
        :param kwargs: filters forwarded to ``cls.where``.
        """
        import builtins

        stmt = cls.where(**kwargs)
        if related:
            # ``list`` is shadowed by the parameter above, hence ``builtins.list``.
            paths = related if isinstance(related, (builtins.list, tuple)) else [related]
            # Statements are immutable: keep the object returned by options().
            stmt = stmt.options(*[selectinload(path) for path in paths])
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            scalars = results.scalars()
            return scalars.all() if list else scalars.first()

    @classmethod
    def _import(cls, schema_in: Union[InDBSchemaType, Dict]):
        """Convert a Pydantic schema to a dict; dicts are returned unchanged."""
        if isinstance(schema_in, dict):
            return schema_in
        return schema_in.dict(exclude_unset=True)

    async def save(self):
        """Add this instance to a fresh session, flush and commit.

        The transaction is rolled back and the error re-raised on failure.

        :return: this instance.
        """
        async with async_session_factory() as session:
            try:
                session.add(self)
                await session.flush()
                await session.commit()
            except Exception:
                await session.rollback()
                raise
        return self

    async def flush_commit_session(self):
        """Flush and commit a freshly opened session, rolling back on error.

        NOTE(review): the session is brand new and ``self`` is not added to it,
        so nothing about this instance is persisted here.
        """
        async with async_session_factory() as session:
            try:
                await session.flush()
                await session.commit()
            except Exception:
                await session.rollback()
                raise
        return self

    @classmethod
    async def save_all(cls, items):
        """Add ``items`` to one session, flush and commit (rollback on error)."""
        async with async_session_factory() as session:
            try:
                session.add_all(items)
                await session.flush()
                await session.commit()
            except Exception:
                await session.rollback()
                raise
        return items

    @classmethod
    async def get_one(cls, **kwargs):
        """Return the first record matching ``kwargs``, or None."""
        stmt = cls.where(**kwargs)
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            return results.scalars().first()

    @classmethod
    async def get_all(cls, **kwargs):
        """Return all records matching ``kwargs`` (de-duplicated)."""
        stmt = cls.where(**kwargs)
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            return results.unique().scalars().all()

    @classmethod
    async def from_smart_query(cls, query):
        """Execute a prepared query and return its de-duplicated entities."""
        async with async_session_factory() as session:
            results = await session.execute(query)
            return results.unique().scalars().all()

    @classmethod
    async def count_where(cls, filters):
        """Count the records matching ``filters``.

        :param filters: filters understood by ``smart_query``.
        :return: int
        """
        filter_stmt = smart_query(query=select(cls), filters=filters)
        # Count over an explicit subquery: ``Select.c`` and passing a bare
        # Select to ``select_from`` rely on implicit subquery coercion, which
        # SQLAlchemy 1.4 deprecates.
        subquery = filter_stmt.subquery()
        count_stmt = select(func.count(subquery.c.uid)).select_from(subquery)
        async with async_session_factory() as session:
            res = await session.execute(count_stmt)
            return res.scalars().one()

    @classmethod
    async def fulltext_search(cls, search_string, field):
        """Full-text search with PostgreSQL on the column named ``field``."""
        stmt = select(cls).filter(
            func.to_tsvector("english", getattr(cls, field)).match(
                search_string, postgresql_regconfig="english"
            )
        )
        async with async_session_factory() as session:
            results = await session.execute(stmt)
            return results.scalars().all()

    @classmethod
    async def get_by_uids(cls, uids: List[Any]):
        """Return the records whose ``uid`` is in ``uids``, ordered by uid."""
        stmt = select(cls).where(cls.uid.in_(uids))  # type: ignore
        async with async_session_factory() as session:
            results = await session.execute(stmt.order_by(cls.uid))
            return results.scalars().all()

    @classmethod
    async def stream_by_uids(cls, uids: List[Any]) -> AsyncIterator[Any]:
        """Stream result rows for the given uids, ordered by uid."""
        stmt = select(cls).where(cls.uid.in_(uids))  # type: ignore
        async with async_session_factory() as session:
            stream = await session.stream(stmt.order_by(cls.uid))
            async for row in stream:
                yield row

    @classmethod
    async def stream_all(cls) -> AsyncIterator[Any]:
        """Stream result rows for every record, ordered by uid."""
        stmt = select(cls)
        async with async_session_factory() as session:
            stream = await session.stream(stmt.order_by(cls.uid))
            async for row in stream:
                yield row

    @staticmethod
    def psql_records_to_dict(self, records, many=False):
        """Convert driver record(s) to plain dict(s).

        NOTE(review): this is a ``staticmethod`` yet its first parameter is
        named ``self``, so callers must pass a throw-away first argument before
        ``records``; confirm the intended signature.

        :param records: a single record, or an iterable of records.
        :param many: if true, convert every element of ``records``.
        :return: ``dict(records)`` for a single truthy record, otherwise a list
            with ``dict(record)`` for each element.
        """
        # records._row: asyncpg.Record / databases.backends.postgres.Record
        if not many and records:
            return dict(records)
        return [dict(record) for record in records]

    # https://engage.so/blog/a-deep-dive-into-offset-and-cursor-based-pagination-in-mongodb/
    # https://medium.com/swlh/how-to-implement-cursor-pagination-like-a-pro-513140b65f32
    @classmethod
    async def paginate_with_cursors(
        cls,
        page_size: Optional[int] = None,
        after_cursor: Any = None,
        before_cursor: Any = None,
        filters: Any = None,
        sort_by: Optional[List[str]] = None,
        get_related: Optional[str] = None,
    ) -> PageCursor:
        """Return a cursor-paginated page of records.

        :param page_size: maximum number of items to return; falsy means no limit.
        :param after_cursor: cursor from ``encode_cursor``; only records with a
            greater ``uid`` are returned.
        :param before_cursor: cursor from ``encode_cursor``; only records with a
            smaller ``uid`` are returned. Replaces ``after_cursor`` if both are set.
        :param filters: dict or list of filters understood by ``smart_query``.
            A list passed by the caller is not modified.
        :param sort_by: sort attributes forwarded to ``smart_query``.
        :param get_related: relationship to eager-load with ``selectinload``.
        :return: PageCursor with total count, edges, items and page info.

        NOTE(review): with ``before_cursor`` and a LIMIT, rows are still taken
        from the start of the sort order, not necessarily those immediately
        preceding the cursor.
        """
        if not filters:
            filters = {}

        # Total count ignores the cursor limits so it covers the whole result set.
        total_count: int = await cls.count_where(filters=filters)
        total_count = total_count if total_count else 0

        cursor_limit = {}
        if has_value_or_is_truthy(after_cursor):
            cursor_limit = {"uid__gt": cls.decode_cursor(after_cursor)}
        if has_value_or_is_truthy(before_cursor):
            cursor_limit = {"uid__lt": cls.decode_cursor(before_cursor)}

        # Add the paging filters without mutating the caller's filters: appending
        # to a shared list would accumulate cursor conditions across calls.
        _filters = None
        if isinstance(filters, dict):
            _filters = [{sa_or_: cursor_limit}, filters] if cursor_limit else filters
        elif isinstance(filters, list):
            _filters = [*filters, {sa_or_: cursor_limit}] if cursor_limit else filters

        stmt = cls.smart_query(filters=_filters, sort_attrs=sort_by)
        if get_related:
            stmt = stmt.options(selectinload(get_related))
        if page_size:
            # Fetch one extra row so "has next page" is exact rather than a
            # guess based on a full page.
            stmt = stmt.limit(page_size + 1)

        async with async_session_factory() as session:
            res = await session.execute(stmt)
            qs = res.scalars().all()

        if page_size:
            items = qs[:page_size]
            has_additional = len(qs) > page_size
        else:
            items = qs
            has_additional = True

        page_info = {
            "start_cursor": cls.encode_cursor(items[0].uid) if items else None,
            "end_cursor": cls.encode_cursor(items[-1].uid) if items else None,
        }
        if page_size is not None:
            page_info["has_next_page"] = has_additional
            page_info["has_previous_page"] = bool(after_cursor)

        return PageCursor(
            **{
                "total_count": total_count,
                "edges": cls.build_edges(items=items),
                "items": items,
                "page_info": cls.build_page_info(**page_info),
            }
        )

    @classmethod
    def build_edges(cls, items: List[Any]) -> List[EdgeNode]:
        """Wrap each item in an ``EdgeNode`` carrying its cursor."""
        if not items:
            return []
        return [cls.build_node(item) for item in items]

    @classmethod
    def build_node(cls, item: Any) -> EdgeNode:
        """Build an ``EdgeNode`` whose cursor encodes ``item.uid``."""
        return EdgeNode(**{"cursor": cls.encode_cursor(item.uid), "node": item})

    @classmethod
    def build_page_info(
        cls,
        start_cursor: Optional[str] = None,
        end_cursor: Optional[str] = None,
        has_next_page: bool = False,
        has_previous_page: bool = False,
    ) -> PageInfo:
        """Build the ``PageInfo`` block of a paginated response."""
        return PageInfo(
            **{
                "start_cursor": start_cursor,
                "end_cursor": end_cursor,
                "has_next_page": has_next_page,
                "has_previous_page": has_previous_page,
            }
        )

    @classmethod
    def decode_cursor(cls, cursor):
        """Decode a base64 cursor; return an int when the payload is numeric.

        Invalid base64 propagates the decoder's error.
        """
        decoded = b64decode(cursor.encode("ascii")).decode("utf8")
        try:
            return int(decoded)
        except ValueError:
            # Non-numeric identifiers are returned as strings.
            return decoded

    @classmethod
    def encode_cursor(cls, identifier: Any):
        """Encode ``str(identifier)`` as an ASCII base64 cursor string."""
        return b64encode(str(identifier).encode("utf8")).decode("ascii")