mirror of
https://github.com/beak-insights/felicity-lims.git
synced 2025-02-25 01:13:01 +08:00
397 lines
13 KiB
Python
397 lines
13 KiB
Python
|
from typing import Generic, TypeVar, Any, List, AsyncIterator, Optional
|
||
|
|
||
|
try:
|
||
|
from typing import Self
|
||
|
except ImportError:
|
||
|
from typing_extensions import Self
|
||
|
|
||
|
from sqlalchemy import or_ as sa_or_
|
||
|
from sqlalchemy import select
|
||
|
from sqlalchemy import update
|
||
|
from sqlalchemy.sql import func
|
||
|
from sqlalchemy.sql.expression import bindparam
|
||
|
|
||
|
from felicity.apps.common.utils.serializer import marshaller
|
||
|
from felicity.database.paginator.cursor import PageCursor, EdgeNode, PageInfo
|
||
|
from felicity.database.queryset import (
|
||
|
QueryBuilder,
|
||
|
settable_attributes,
|
||
|
smart_query,
|
||
|
)
|
||
|
from felicity.database.session import async_session
|
||
|
|
||
|
M = TypeVar("M")
|
||
|
|
||
|
|
||
|
class BaseRepository(Generic[M]):
|
||
|
async_session = async_session
|
||
|
model: M = None
|
||
|
|
||
|
def __init__(self) -> None:
|
||
|
self._qb = QueryBuilder(model=self.model)
|
||
|
|
||
|
@staticmethod
|
||
|
def fill(m: M, **kwargs):
|
||
|
for name in kwargs.keys():
|
||
|
if name in settable_attributes(m):
|
||
|
setattr(m, name, kwargs[name])
|
||
|
else:
|
||
|
raise KeyError("Attribute '{}' doesn't exist".format(name))
|
||
|
return m
|
||
|
|
||
|
async def save(self, m: M) -> M:
|
||
|
async with self.async_session() as session:
|
||
|
try:
|
||
|
session.add(m)
|
||
|
await session.flush()
|
||
|
await session.commit()
|
||
|
except Exception:
|
||
|
await session.rollback()
|
||
|
raise
|
||
|
return m
|
||
|
|
||
|
@classmethod
|
||
|
async def save_all(cls, items):
|
||
|
async with cls.async_session() as session:
|
||
|
try:
|
||
|
session.add_all(items)
|
||
|
await session.flush()
|
||
|
await session.commit()
|
||
|
except Exception:
|
||
|
await session.rollback()
|
||
|
raise
|
||
|
return items
|
||
|
|
||
|
async def create(self, **kwargs) -> M:
|
||
|
cls = self.model()
|
||
|
filled = self.fill(cls, **kwargs)
|
||
|
return await self.save(filled)
|
||
|
|
||
|
async def bulk_create(self, bulk: list[dict]) -> list[M]:
|
||
|
to_save = []
|
||
|
for data in bulk:
|
||
|
fill = await self.fill(self.model(), **data)
|
||
|
to_save.append(fill)
|
||
|
return await self.save_all(to_save)
|
||
|
|
||
|
async def update(self, model: M, **kwargs) -> M:
|
||
|
filled = self.fill(model, **kwargs)
|
||
|
return await self.save(filled)
|
||
|
|
||
|
async def update_by_uid(self, uid: str, **kwargs) -> M:
|
||
|
_update = await self.get(uid=uid)
|
||
|
filled = self.fill(_update, **kwargs)
|
||
|
return await self.save(filled)
|
||
|
|
||
|
async def bulk_update_where(self, update_data: list[dict], filters: dict):
|
||
|
"""
|
||
|
@param update_data a List of dictionary update values.
|
||
|
@param filters is a dict of filter values.
|
||
|
e.g [{'uid': 34, update_values}, ...]
|
||
|
"""
|
||
|
query = smart_query(query=update(self.model), filters=filters)
|
||
|
stmt = query.values(update_data).execution_options(synchronize_session="fetch")
|
||
|
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
updated = results.scalars().all()
|
||
|
return updated
|
||
|
|
||
|
async def bulk_update_with_mappings(self, mappings: list) -> None:
|
||
|
"""
|
||
|
@param mappings a List of dictionary update values with pks.
|
||
|
e.g [{'uid': 34, update_values}, ...]
|
||
|
?? there must be zero many-to-many relations
|
||
|
NB: Function does not return anything
|
||
|
"""
|
||
|
if len(mappings) == 0:
|
||
|
return
|
||
|
|
||
|
to_update = [marshaller(data) for data in mappings]
|
||
|
for item in to_update:
|
||
|
item["_uid"] = item["uid"]
|
||
|
|
||
|
query = update(self.model).where(self.model.uid == bindparam("_uid"))
|
||
|
|
||
|
binds = {}
|
||
|
for key in to_update[0]:
|
||
|
if key != "_uid":
|
||
|
binds[key] = bindparam(key)
|
||
|
|
||
|
stmt = query.values(binds).execution_options(
|
||
|
synchronize_session=None
|
||
|
) # "fetch" not available
|
||
|
|
||
|
async with self.async_session() as session:
|
||
|
await session.execute(stmt, to_update)
|
||
|
await session.flush()
|
||
|
await session.commit()
|
||
|
|
||
|
async def table_insert(self, table: Any, mappings: list[dict]) -> None:
|
||
|
"""
|
||
|
@param table is a sqlalchemy table model
|
||
|
@param mappings a dictionary update values.
|
||
|
e.g {'name': 34, 'day': "fff"}
|
||
|
"""
|
||
|
async with self.async_session() as session:
|
||
|
stmt = table.insert()
|
||
|
await session.execute(stmt, mappings)
|
||
|
await session.commit()
|
||
|
await session.flush()
|
||
|
|
||
|
async def query_table(self, table, **kwargs):
|
||
|
stmt = select(table)
|
||
|
for k, v in kwargs.items():
|
||
|
stmt = stmt.where(table.c[k] == v)
|
||
|
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
return results.unique().scalars().all()
|
||
|
|
||
|
async def get(self, **kwargs) -> M:
|
||
|
stmt = self._qb.where(**kwargs)
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
found = results.scalars().first()
|
||
|
return found
|
||
|
|
||
|
async def get_all(self, **kwargs) -> list[M]:
|
||
|
stmt = self._qb.where(**kwargs)
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
found = results.scalars().all()
|
||
|
return found
|
||
|
|
||
|
async def all(self) -> list[M]:
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(select(self.model))
|
||
|
return results.scalars().all()
|
||
|
|
||
|
async def all_by_page(self, page: int = 1, limit: int = 20, **kwargs) -> dict:
|
||
|
start = (page - 1) * limit
|
||
|
|
||
|
stmt = self._qb.where(**kwargs).limit(limit).offset(start)
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
found = results.scalars().all()
|
||
|
return found
|
||
|
|
||
|
async def get_by_uids(self, uids: List[str]) -> list[M]:
|
||
|
stmt = select(self.model).where(self.model.uid.in_(uids)) # type: ignore
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt.order_by(self.model.uid))
|
||
|
return results.scalars().all()
|
||
|
|
||
|
async def get_related(
|
||
|
self, related: Optional[list] = None, many: bool = False, **kwargs
|
||
|
):
|
||
|
"""Return the first value in database based on given args."""
|
||
|
try:
|
||
|
del kwargs["related"]
|
||
|
except KeyError:
|
||
|
pass
|
||
|
|
||
|
try:
|
||
|
del kwargs["many"]
|
||
|
except KeyError:
|
||
|
pass
|
||
|
|
||
|
stmt = self._qb.where(**kwargs)
|
||
|
# if related:
|
||
|
# stmt.options(selectinload(related))
|
||
|
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
|
||
|
if not many:
|
||
|
found = results.scalars().first()
|
||
|
else:
|
||
|
found = results.scalars().all()
|
||
|
|
||
|
return found
|
||
|
|
||
|
async def stream_by_uids(self, uids: List[Any]) -> AsyncIterator[M]:
|
||
|
|
||
|
stmt = select(self.model).where(self.model.in_(uids)) # type: ignore
|
||
|
|
||
|
async with self.async_session() as session:
|
||
|
stream = await session.stream(stmt.order_by(self.model.uid))
|
||
|
async for row in stream:
|
||
|
yield row
|
||
|
|
||
|
async def stream_all(self) -> AsyncIterator[Any]:
|
||
|
stmt = select(self.model)
|
||
|
async with self.async_session() as session:
|
||
|
stream = await session.stream(stmt.order_by(self.model.uid))
|
||
|
async for row in stream:
|
||
|
yield row
|
||
|
|
||
|
async def full_text_search(self, search_string, field):
|
||
|
"""Full-text Search with PostgreSQL"""
|
||
|
stmt = select(self.model).filter(
|
||
|
func.to_tsvector("english", getattr(self.model, field)).match(
|
||
|
search_string, postgresql_regconfig="english"
|
||
|
)
|
||
|
)
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
search = results.scalars().all()
|
||
|
return search
|
||
|
|
||
|
async def delete(self, uid: str) -> None:
|
||
|
obj = await self.get(uid=uid)
|
||
|
async with self.async_session() as session:
|
||
|
await session.delete(obj)
|
||
|
await session.flush()
|
||
|
await session.commit()
|
||
|
|
||
|
async def count_where(self, filters: dict) -> int:
|
||
|
"""
|
||
|
:param filters:
|
||
|
:return: int
|
||
|
"""
|
||
|
# filter_stmt = smart_query(query=select(cls), filters=filters) noqa
|
||
|
filter_stmt = self._qb.smart_query(filters=filters)
|
||
|
count_stmt = select(func.count(filter_stmt.c.uid)).select_from(filter_stmt)
|
||
|
async with self.async_session() as session:
|
||
|
res = await session.execute(count_stmt)
|
||
|
count = res.scalars().one()
|
||
|
return count
|
||
|
|
||
|
async def search(self, **kwargs) -> list[M]:
|
||
|
filters = []
|
||
|
combined = set()
|
||
|
for k, v in kwargs:
|
||
|
filter_string = f"{k}__ilike"
|
||
|
filters.append(filter_string)
|
||
|
|
||
|
arg = dict()
|
||
|
arg[k] = f"%{v}%"
|
||
|
query = await self.get_all(**arg)
|
||
|
for item in query:
|
||
|
combined.add(item)
|
||
|
|
||
|
return list(combined)
|
||
|
|
||
|
async def filter(
|
||
|
self,
|
||
|
filters: list[dict],
|
||
|
sort_attrs: list[str] | None = None,
|
||
|
limit: int | None = None,
|
||
|
either: bool = False,
|
||
|
) -> list[M]:
|
||
|
if either:
|
||
|
filters = {sa_or_: filters}
|
||
|
|
||
|
stmt = self._qb.smart_query(filters, sort_attrs)
|
||
|
if limit:
|
||
|
stmt = stmt.limit(limit)
|
||
|
async with self.async_session() as session:
|
||
|
results = await session.execute(stmt)
|
||
|
found = results.scalars().all()
|
||
|
return found
|
||
|
|
||
|
async def paginate_with_cursors(
|
||
|
self,
|
||
|
page_size: int | None,
|
||
|
after_cursor: str | None,
|
||
|
before_cursor: str | None,
|
||
|
filters: dict | list[dict] | None,
|
||
|
sort_by: list[str] | None,
|
||
|
**kwargs,
|
||
|
) -> PageCursor:
|
||
|
if not filters:
|
||
|
filters = {}
|
||
|
|
||
|
# get total count without paging filters from cursors
|
||
|
total_count: int = await self.count_where(filters=filters)
|
||
|
total_count = total_count if total_count else 0
|
||
|
|
||
|
cursor_limit = {}
|
||
|
if after_cursor:
|
||
|
cursor_limit = {"uid__gt": self.decode_cursor(after_cursor)}
|
||
|
|
||
|
if before_cursor:
|
||
|
cursor_limit = {"uid__lt": self.decode_cursor(before_cursor)}
|
||
|
|
||
|
# add paging filters
|
||
|
_filters = None
|
||
|
if isinstance(filters, dict):
|
||
|
_filters = [{sa_or_: cursor_limit}, filters] if cursor_limit else filters
|
||
|
elif isinstance(filters, list):
|
||
|
_filters = filters
|
||
|
if cursor_limit:
|
||
|
_filters.append({sa_or_: cursor_limit})
|
||
|
|
||
|
stmt = self._qb.smart_query(filters=_filters, sort_attrs=sort_by)
|
||
|
|
||
|
if kwargs.get("get_related"):
|
||
|
# stmt = stmt.options(selectinload(get_related)) noqa
|
||
|
pass
|
||
|
|
||
|
if page_size:
|
||
|
stmt = stmt.limit(page_size)
|
||
|
|
||
|
async with self.async_session() as session:
|
||
|
res = await session.execute(stmt)
|
||
|
|
||
|
qs = res.scalars().all()
|
||
|
|
||
|
if qs is not None:
|
||
|
items = qs
|
||
|
else:
|
||
|
qs = []
|
||
|
items = []
|
||
|
|
||
|
has_additional = (
|
||
|
len(items) == page_size if page_size else True
|
||
|
) # len(qs) > len(items)s
|
||
|
page_info = {
|
||
|
"start_cursor": self.encode_cursor(items[0].uid) if items else None,
|
||
|
"end_cursor": self.encode_cursor(items[-1].uid) if items else None,
|
||
|
}
|
||
|
if page_size is not None:
|
||
|
page_info["has_next_page"] = has_additional
|
||
|
page_info["has_previous_page"] = bool(after_cursor)
|
||
|
|
||
|
return PageCursor(
|
||
|
**{
|
||
|
"total_count": total_count,
|
||
|
"edges": self.build_edges(items=items),
|
||
|
"items": items,
|
||
|
"page_info": self.build_page_info(**page_info),
|
||
|
}
|
||
|
)
|
||
|
|
||
|
def build_edges(self, items: List[Any]) -> List[EdgeNode]:
|
||
|
if not items:
|
||
|
return []
|
||
|
return [self.build_node(item) for item in items]
|
||
|
|
||
|
def build_node(self, item: Any) -> EdgeNode:
|
||
|
return EdgeNode(**{"cursor": self.encode_cursor(item.uid), "node": item})
|
||
|
|
||
|
@staticmethod
|
||
|
def build_page_info(
|
||
|
start_cursor: str = None,
|
||
|
end_cursor: str = None,
|
||
|
has_next_page: bool = False,
|
||
|
has_previous_page: bool = False,
|
||
|
) -> PageInfo:
|
||
|
return PageInfo(
|
||
|
**{
|
||
|
"start_cursor": start_cursor,
|
||
|
"end_cursor": end_cursor,
|
||
|
"has_next_page": has_next_page,
|
||
|
"has_previous_page": has_previous_page,
|
||
|
}
|
||
|
)
|
||
|
|
||
|
@staticmethod
|
||
|
def decode_cursor(cursor):
|
||
|
return cursor
|
||
|
|
||
|
@staticmethod
|
||
|
def encode_cursor(identifier: Any):
|
||
|
return identifier
|