
433 lines
15 KiB
Raw Normal View History

2024-07-28 03:52:31 +08:00
from typing import Any, AsyncIterator, Generic, List, Optional, TypeVar
2024-07-21 15:06:51 +08:00
from sqlalchemy import inspect, or_ as sa_or_, Table
2024-07-28 03:52:31 +08:00
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
2024-07-21 15:06:51 +08:00
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import bindparam
from felicity.apps.abstract.entity import BaseEntity
2024-07-28 03:52:31 +08:00
from felicity.database.paging import EdgeNode, PageCursor, PageInfo
2024-07-21 15:06:51 +08:00
from felicity.database.session import async_session
2024-07-28 03:52:31 +08:00
# Type variable for the entity class a repository manages; bound to BaseEntity.
M = TypeVar("M", bound=BaseEntity)
def apply_nested_loader_options(stmt, model, path):
    """
    Apply chained ``selectinload`` options for a dot-separated relationship path.

    :param stmt: The SQLAlchemy statement to add loader options to.
    :param model: The base model from which the relationship path starts.
    :param path: A dot-separated string naming the nested relationship path,
        e.g. ``"parent.children"``. A path without dots is a single relationship.
    :return: The statement with the loader option for the whole path applied.
    """
    load_option = None
    current_model = model

    # str.split already yields [path] when there is no ".", so no special case.
    for attr in path.split("."):
        relationship_attr = getattr(current_model, attr)
        if load_option is None:
            load_option = selectinload(relationship_attr)
        else:
            # Loader options are generative: each call returns a new object.
            # Chain onto the accumulated option and keep the result so every
            # level of the path is eager-loaded, not only the first one.
            load_option = load_option.selectinload(relationship_attr)

        # Update the current model to the next model in the relationship path
        current_model = inspect(current_model).relationships[attr].mapper.class_

    return stmt.options(load_option)
2024-07-21 15:06:51 +08:00
class BaseRepository(Generic[M]):
    """
    Generic asynchronous repository for a model class bound to ``BaseEntity``.

    Every database operation opens its own session from ``async_session``.
    The model is expected to provide ``fill``, ``where`` and ``smart_query``
    helpers (defined outside this module).
    """

    # Session factory shared by all repositories.
    async_session = async_session
    # Model class managed by this repository; set per instance in __init__.
    model: M = None

    def __init__(self, model: M) -> None:
        self.model = model

    async def save(self, m: M) -> M:
        """
        Add a single entity to a new session and commit it.

        :param m: The entity to persist.
        :return: The same entity that was passed in.
        :raises ValueError: If ``m`` is falsy.
        """
        if not m:
            raise ValueError("No model provided to save")  # noqa

        async with self.async_session() as session:
            try:
                session.add(m)
                await session.flush()
                await session.commit()
            except Exception:
                # Undo the partial transaction, then surface the failure
                # instead of returning an entity that was never stored.
                await session.rollback()
                raise
        return m

    async def save_all(self, items):
        """
        Add several entities to a new session and commit them together.

        :param items: The entities to persist.
        :return: The same ``items`` that were passed in.
        :raises ValueError: If ``items`` is empty.
        """
        if not items:
            raise ValueError("No items provided to save")

        async with self.async_session() as session:
            try:
                session.add_all(items)
                await session.flush()
                await session.commit()
            except Exception:
                await session.rollback()
                raise
        return items

    async def create(self, **kwargs) -> M:
        """
        Build a new model instance from ``kwargs`` and save it.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No data provided to create a new model")

        filled = self.model.fill(self.model(), **kwargs)
        return await self.save(filled)

    async def bulk_create(self, bulk: list[dict]) -> list[M]:
        """
        Build one model instance per dict in ``bulk`` and save them all.

        :raises ValueError: If ``bulk`` is empty.
        """
        if not bulk:
            raise ValueError("No data provided to create a new models")

        to_save = [self.model.fill(self.model(), **data) for data in bulk]
        return await self.save_all(to_save)

    async def update(self, uid: str, **data) -> M:
        """
        Load the entity with the given ``uid``, apply ``data`` to it and save it.

        :raises ValueError: If ``uid`` or ``data`` is missing, or if no entity
            with that ``uid`` exists.
        """
        if not uid or not data:
            raise ValueError("Both uid and data are required to update model")

        item = await self.get(uid=uid)
        if item is None:
            # Fail with a clear message rather than passing None into fill().
            raise ValueError(f"No model found with uid {uid} to update")

        filled = self.model.fill(item, **data)
        return await self.save(filled)

    async def bulk_update_where(self, update_data: list[dict], filters: dict):
        """
        Run an UPDATE restricted by ``filters``.

        @param update_data a List of dictionary update values.
        @param filters is a dict of filter values.
        e.g [{'uid': 34, update_values}, ...]
        """
        if not update_data or not filters:
            raise ValueError("Both update_data and filters are required to update model")

        query = self.model.smart_query(query=update(self.model), filters=filters)
        stmt = query.values(update_data).execution_options(synchronize_session="fetch")

        async with self.async_session() as session:
            results = await session.execute(stmt)
            updated = results.scalars().all()
            # Without a commit the update is discarded when the session closes.
            await session.commit()
        return updated

    async def bulk_update_with_mappings(self, mappings: list) -> None:
        """
        Update many rows by primary key in a single executemany call.

        @param mappings a List of dictionary update values with pks.
        e.g [{'uid': 34, update_values}, ...]
        ?? there must be zero many-to-many relations
        NB: Function does not return anything
        """
        if not mappings:
            raise ValueError("No mappings provided to update")

        # Copy each mapping so the caller's dicts do not gain a "_uid" key.
        to_update = [{**item, "_uid": item["uid"]} for item in mappings]

        query = update(self.model).where(self.model.uid == bindparam("_uid"))
        # The column set is taken from the first mapping only.
        binds = {key: bindparam(key) for key in to_update[0] if key != "_uid"}
        stmt = query.values(binds).execution_options(
            synchronize_session=None
        )  # "fetch" not available

        async with self.async_session() as session:
            await session.execute(stmt, to_update)
            await session.flush()
            await session.commit()

    async def table_insert(self, table: Any, mappings: list[dict]) -> None:
        """
        Insert rows directly into a table.

        @param table is a sqlalchemy table model
        @param mappings a dictionary update values.
        e.g {'name': 34, 'day': "fff"}
        """
        async with self.async_session() as session:
            stmt = table.insert()
            await session.execute(stmt, mappings)
            # Flush before commit, matching the other write methods.
            await session.flush()
            await session.commit()

    async def query_table(self, table: Table, columns: list[str], **kwargs):
        """
        Select from ``table`` where every ``kwargs`` column equals its value.

        :param table: The table to query.
        :param columns: Column names to select; all columns when empty.
        :raises ValueError: If ``table`` is None or no filters are given.
        """
        if table is None or not kwargs:
            raise ValueError("Both table and filters are required to query")

        if columns:
            stmt = select(*(table.c[column] for column in columns))
        else:
            stmt = select(table)

        for k, v in kwargs.items():
            stmt = stmt.where(table.c[k] == v)

        async with self.async_session() as session:
            results = await session.execute(stmt)
        # NOTE(review): scalars() keeps only the first selected column of each
        # row; confirm that is intended when several columns are requested.
        return results.unique().scalars().all()  # , results.keys()

    async def get(self, **kwargs) -> M:
        """
        Return the first entity matching ``kwargs``, or None when nothing matches.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No arguments provided to get model")

        stmt = self.model.where(**kwargs)
        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().first()
        return found

    async def get_all(self, **kwargs) -> list[M]:
        """
        Return every entity matching ``kwargs``.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No arguments provided to get all")

        stmt = self.model.where(**kwargs)
        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().all()
        return found

    async def all(self) -> list[M]:
        """Return every entity of the model."""
        async with self.async_session() as session:
            results = await session.execute(select(self.model))
        return results.scalars().all()

    async def all_by_page(self, page: int = 1, limit: int = 20, **kwargs) -> list[M]:
        """
        Return one page of entities matching ``kwargs``.

        :param page: 1-based page number.
        :param limit: Maximum number of entities per page.
        """
        start = (page - 1) * limit

        stmt = self.model.where(**kwargs).limit(limit).offset(start)
        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().all()
        return found

    async def get_by_uids(self, uids: List[str]) -> list[M]:
        """
        Return the entities whose uid is in ``uids``, ordered by uid.

        :raises ValueError: If ``uids`` is empty.
        """
        if not uids:
            raise ValueError("No uids provided to get by uids")

        stmt = select(self.model).where(self.model.uid.in_(uids))  # type: ignore
        async with self.async_session() as session:
            results = await session.execute(stmt.order_by(self.model.uid))
        return results.scalars().all()

    async def get_related(
        self, related: Optional[list[str]], many: bool = False, **kwargs
    ) -> M | list[M] | None:
        """
        Return entities matching ``kwargs`` with the given relationships eager-loaded.

        :param related: Relationship paths (dot-separated for nested ones).
        :param many: Return all matches when True, otherwise only the first.
        :raises ValueError: If ``related`` is empty.
        """
        if not related:
            raise ValueError("No related fields provided to get related")

        stmt = self.model.where(**kwargs)
        for key in related:
            stmt = apply_nested_loader_options(stmt, self.model, key)

        async with self.async_session() as session:
            results = await session.execute(stmt)
            if not many:
                found = results.scalars().first()
            else:
                found = results.scalars().all()
        return found

    async def stream_by_uids(self, uids: List[Any]) -> AsyncIterator[M]:
        """Stream the result rows whose uid is in ``uids``, ordered by uid."""
        # Filter on the uid column, as get_by_uids does.
        stmt = select(self.model).where(self.model.uid.in_(uids))  # type: ignore

        async with self.async_session() as session:
            stream = await session.stream(stmt.order_by(self.model.uid))
            async for row in stream:
                yield row

    async def stream_all(self) -> AsyncIterator[Any]:
        """Stream all result rows of the model, ordered by uid."""
        stmt = select(self.model)
        async with self.async_session() as session:
            stream = await session.stream(stmt.order_by(self.model.uid))
            async for row in stream:
                yield row

    async def full_text_search(self, search_string, field):
        """Full-text Search with PostgreSQL"""
        stmt = select(self.model).filter(
            func.to_tsvector("english", getattr(self.model, field)).match(
                search_string, postgresql_regconfig="english"
            )
        )

        async with self.async_session() as session:
            results = await session.execute(stmt)
            search = results.scalars().all()
        return search

    async def delete(self, uid: str) -> None:
        """
        Delete the entity with the given ``uid``.

        :raises ValueError: If ``uid`` is falsy.
        """
        if not uid:
            raise ValueError("No uid provided to delete")

        obj = await self.get(uid=uid)
        async with self.async_session() as session:
            await session.delete(obj)
            await session.flush()
            await session.commit()

    async def count_where(self, filters: dict) -> int:
        """
        Count the rows selected by ``filters``.

        :param filters: Filters passed to the model's ``smart_query``.
        :return: int
        """
        filter_stmt = self.model.smart_query(filters=filters)
        count_stmt = select(func.count(filter_stmt.c.uid)).select_from(filter_stmt)
        async with self.async_session() as session:
            res = await session.execute(count_stmt)
        return res.scalars().one()

    async def search(self, **kwargs) -> list[M]:
        """
        Return the distinct entities where any given field contains its value.

        Each keyword is looked up separately with a ``<field>__ilike`` filter
        and the results are merged.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No search arguments provided")

        combined = set()
        # Iterate items(): iterating the dict itself yields only the keys.
        for field, value in kwargs.items():
            matches = await self.get_all(**{f"{field}__ilike": f"%{value}%"})
            combined.update(matches)
        return list(combined)

    async def filter(
        self,
        filters: list[dict],
        sort_attrs: list[str] | None = None,
        limit: int | None = None,
        either: bool = False,
    ) -> list[M]:
        """
        Return entities selected by the model's ``smart_query``.

        :param filters: Filters passed to ``smart_query``.
        :param sort_attrs: Sort attributes passed to ``smart_query``.
        :param limit: Optional maximum number of results.
        :param either: Wrap ``filters`` in an OR group when True.
        """
        if either:
            filters = {sa_or_: filters}

        stmt = self.model.smart_query(filters, sort_attrs)
        if limit:
            stmt = stmt.limit(limit)

        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().all()
        return found

    async def paginate(
        self,
        page_size: int | None,
        after_cursor: str | None,
        before_cursor: str | None,
        filters: dict | list[dict] | None,
        sort_by: list[str] | None,
        **kwargs,
    ) -> PageCursor:
        """
        Return one cursor-based page of entities.

        :param page_size: Maximum number of items in the page; no limit if falsy.
        :param after_cursor: Only return items with uid greater than this cursor.
        :param before_cursor: Only return items with uid lower than this cursor;
            takes precedence when both cursors are given.
        :param filters: Filters passed to the model's ``smart_query``.
        :param sort_by: Sort attributes passed to ``smart_query``.
        :param kwargs: ``get_related`` may list relationship paths to eager-load.
        :return: A ``PageCursor`` with total count, edges, items and page info.
        """
        if not filters:
            filters = {}

        # get total count without paging filters from cursors
        total_count: int = await self.count_where(filters=filters)
        total_count = total_count or 0

        cursor_limit = {}
        if after_cursor:
            cursor_limit = {"uid__gt": self.decode_cursor(after_cursor)}
        if before_cursor:
            cursor_limit = {"uid__lt": self.decode_cursor(before_cursor)}

        # add paging filters
        _filters = None
        if isinstance(filters, dict):
            _filters = [{sa_or_: cursor_limit}, filters] if cursor_limit else filters
        elif isinstance(filters, list):
            # Build a new list so the caller's filters are not mutated.
            _filters = [*filters, {sa_or_: cursor_limit}] if cursor_limit else filters

        stmt = self.model.smart_query(filters=_filters, sort_attrs=sort_by)

        for key in kwargs.get("get_related") or []:
            stmt = apply_nested_loader_options(stmt, self.model, key)

        if page_size:
            stmt = stmt.limit(page_size)

        async with self.async_session() as session:
            res = await session.execute(stmt)
            items = res.scalars().all() or []

        # A full page is taken to mean more rows may follow.
        has_additional = len(items) == page_size if page_size else True

        page_info = {
            "start_cursor": self.encode_cursor(items[0].uid) if items else None,
            "end_cursor": self.encode_cursor(items[-1].uid) if items else None,
        }
        if page_size is not None:
            page_info["has_next_page"] = has_additional
            page_info["has_previous_page"] = bool(after_cursor)

        return PageCursor(
            **{
                "total_count": total_count,
                "edges": self.build_edges(items=items),
                "items": items,
                "page_info": self.build_page_info(**page_info),
            }
        )

    def build_edges(self, items: List[Any]) -> List[EdgeNode]:
        """Wrap each item in an ``EdgeNode``; returns an empty list for no items."""
        if not items:
            return []
        return [self.build_node(item) for item in items]

    def build_node(self, item: Any) -> EdgeNode:
        """Build the ``EdgeNode`` for ``item``, using its uid as the cursor."""
        return EdgeNode(**{"cursor": self.encode_cursor(item.uid), "node": item})

    def build_page_info(
        self,
        start_cursor: str | None = None,
        end_cursor: str | None = None,
        has_next_page: bool = False,
        has_previous_page: bool = False,
    ) -> PageInfo:
        """Build a ``PageInfo`` from the given cursors and paging flags."""
        return PageInfo(
            **{
                "start_cursor": start_cursor,
                "end_cursor": end_cursor,
                "has_next_page": has_next_page,
                "has_previous_page": has_previous_page,
            }
        )

    @staticmethod
    def decode_cursor(cursor):
        """Return the identifier for ``cursor``; cursors are currently plain uids."""
        return cursor

    @staticmethod
    def encode_cursor(identifier: Any):
        """Return the cursor for ``identifier``; cursors are currently plain uids."""
        return identifier