from typing import Any, AsyncIterator, Generic, List, Optional, TypeVar

from sqlalchemy import inspect, or_ as sa_or_, Table
from sqlalchemy import select, update
from sqlalchemy.orm import selectinload
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import bindparam

from felicity.apps.abstract.entity import BaseEntity
from felicity.database.paging import EdgeNode, PageCursor, PageInfo
from felicity.database.session import async_session

M = TypeVar("M", bound=BaseEntity)


def apply_nested_loader_options(stmt, model, path):
    """
    Apply ``selectinload`` options for a (possibly nested) relationship path.

    :param stmt: The SQLAlchemy statement to add loader options to.
    :param model: The base model from which the relationship path starts.
    :param path: A dot-separated string naming the nested relationship path,
        e.g. ``"items.keywords"``.
    :return: The statement with the chained loader option applied.

    For ``path="items.keywords"`` on ``Order`` this is equivalent to::

        select(Order).options(
            selectinload(Order.items).selectinload(Item.keywords)
        )
    """
    load_option = None
    current_model = model

    # str.split already returns [path] when there is no "." in it.
    for attr in path.split("."):
        relationship_attr = getattr(current_model, attr)
        if load_option is None:
            load_option = selectinload(relationship_attr)
        else:
            # Chain on the previous loader so every level of the path is
            # part of the single option handed to stmt.options().
            load_option = load_option.selectinload(relationship_attr)

        # Step to the model on the other side of the relationship.
        current_model = inspect(current_model).relationships[attr].mapper.class_

    return stmt.options(load_option)


class BaseRepository(Generic[M]):
    """
    Generic async repository around a single model class.

    Every method opens its own session from ``async_session`` for the
    duration of the call.
    """

    async_session = async_session
    model: M = None

    def __init__(self, model: M) -> None:
        """
        :param model: The model class this repository operates on.
        """
        self.model = model

    async def save(self, m: M) -> M:
        """
        Add a model instance to a session, flush and commit it.

        :param m: The instance to persist.
        :return: The same instance.
        :raises ValueError: If ``m`` is falsy.
        """
        if not m:
            raise ValueError("No model provided to save")  # noqa

        async with self.async_session() as session:
            try:
                session.add(m)
                await session.flush()
                await session.commit()
            except Exception:
                # Undo the partial work, then let the caller see the error.
                await session.rollback()
                raise
        return m

    async def save_all(self, items: list[M]) -> list[M]:
        """
        Add several model instances to one session, flush and commit them.

        :param items: The instances to persist.
        :return: The same ``items`` object.
        :raises ValueError: If ``items`` is empty.
        """
        if not items:
            raise ValueError("No items provided to save")

        async with self.async_session() as session:
            try:
                session.add_all(items)
                await session.flush()
                await session.commit()
            except Exception:
                await session.rollback()
                raise
        return items

    async def create(self, **kwargs) -> M:
        """
        Build a new model instance from ``kwargs`` and save it.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No data provided to create a new model")
        filled = self.model.fill(self.model(), **kwargs)
        return await self.save(filled)

    async def bulk_create(self, bulk: list[dict]) -> list[M]:
        """
        Build one model instance per dict in ``bulk`` and save them together.

        :raises ValueError: If ``bulk`` is empty.
        """
        if not bulk:
            raise ValueError("No data provided to create a new models")
        to_save = [self.model.fill(self.model(), **data) for data in bulk]
        return await self.save_all(to_save)

    async def update(self, uid: str, **data) -> M:
        """
        Fetch the model with ``uid``, fill it with ``data`` and save it.

        :raises ValueError: If ``uid`` or ``data`` is missing.
        """
        if not uid or not data:
            raise ValueError("Both uid and data are required to update model")
        item = await self.get(uid=uid)
        filled = self.model.fill(item, **data)
        return await self.save(filled)

    async def bulk_update_where(self, update_data: list[dict], filters: dict):
        """
        Execute an UPDATE on the model restricted by ``filters``.

        :param update_data: The update values passed to ``.values()``.
        :param filters: A dict of filter values handed to ``smart_query``.
        :return: ``results.scalars().all()`` of the executed statement.
        :raises ValueError: If ``update_data`` or ``filters`` is empty.

        NOTE(review): no commit is issued here and the statement carries no
        RETURNING clause -- confirm that the changes persist and that
        ``scalars()`` has rows to return.
        """
        if not update_data or not filters:
            raise ValueError("Both update_data and filters are required to update model")

        query = self.model.smart_query(query=update(self.model), filters=filters)
        stmt = query.values(update_data).execution_options(synchronize_session="fetch")

        async with self.async_session() as session:
            results = await session.execute(stmt)
            updated = results.scalars().all()
        return updated

    async def bulk_update_with_mappings(self, mappings: list) -> None:
        """
        Update many rows by ``uid`` with one executemany-style statement.

        :param mappings: A list of dicts of update values, each including the
            row's ``uid``, e.g. ``[{'uid': 34, ...update values}, ...]``.
            The keys of the first dict decide which columns are updated.
        :raises ValueError: If ``mappings`` is empty.

        ?? there must be zero many-to-many relations
        NB: Function does not return anything
        """
        if not mappings:
            raise ValueError("No mappings provided to update")

        # Work on copies so the caller's dicts do not gain a "_uid" key.
        to_update = [{**item, "_uid": item["uid"]} for item in mappings]

        query = update(self.model).where(self.model.uid == bindparam("_uid"))

        binds = {key: bindparam(key) for key in to_update[0] if key != "_uid"}

        stmt = query.values(binds).execution_options(
            synchronize_session=None
        )  # "fetch" not available

        async with self.async_session() as session:
            await session.execute(stmt, to_update)
            await session.flush()
            await session.commit()

    async def table_insert(self, table: Any, mappings: list[dict]) -> None:
        """
        Insert rows into a plain table.

        :param table: A SQLAlchemy table object providing ``insert()``.
        :param mappings: A list of dicts of column values,
            e.g. ``[{'name': 34, 'day': "fff"}]``.
        """
        async with self.async_session() as session:
            stmt = table.insert()
            await session.execute(stmt, mappings)
            # Flush before committing, as the other write methods do.
            await session.flush()
            await session.commit()

    async def query_table(self, table: Table, columns: list[str], **kwargs):
        """
        Select from a plain table with equality filters.

        :param table: The table to query.
        :param columns: Column names to select; all columns when empty.
        :param kwargs: ``column=value`` equality filters, combined with AND.
        :return: ``unique().scalars().all()`` of the result.
        :raises ValueError: If ``table`` is None or no filters are given.

        NOTE(review): ``scalars()`` keeps only the first selected column --
        confirm callers never need more than one.
        """
        if table is None or not kwargs:
            raise ValueError("Both table and filters are required to query")

        if columns:
            stmt = select(*(table.c[column] for column in columns))
        else:
            stmt = select(table)

        for k, v in kwargs.items():
            stmt = stmt.where(table.c[k] == v)

        async with self.async_session() as session:
            results = await session.execute(stmt)
        return results.unique().scalars().all()  # , results.keys()

    async def get(self, **kwargs) -> Optional[M]:
        """
        Return the first model matching ``kwargs``, or None.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No arguments provided to get model")
        stmt = self.model.where(**kwargs)
        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().first()
        return found

    async def get_all(self, **kwargs) -> list[M]:
        """
        Return every model matching ``kwargs``.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No arguments provided to get all")
        stmt = self.model.where(**kwargs)
        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().all()
        return found

    async def all(self) -> list[M]:
        """Return every row of the model."""
        async with self.async_session() as session:
            results = await session.execute(select(self.model))
        return results.scalars().all()

    async def all_by_page(self, page: int = 1, limit: int = 20, **kwargs) -> list[M]:
        """
        Return one offset-based page of models matching ``kwargs``.

        :param page: 1-based page number.
        :param limit: Page size.
        """
        start = (page - 1) * limit

        stmt = self.model.where(**kwargs).limit(limit).offset(start)
        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().all()
        return found

    async def get_by_uids(self, uids: List[str]) -> list[M]:
        """
        Return the models whose ``uid`` is in ``uids``, ordered by ``uid``.

        :raises ValueError: If ``uids`` is empty.
        """
        if not uids:
            raise ValueError("No uids provided to get by uids")
        stmt = select(self.model).where(self.model.uid.in_(uids))  # type: ignore

        async with self.async_session() as session:
            results = await session.execute(stmt.order_by(self.model.uid))
        return results.scalars().all()

    async def get_related(
        self, related: Optional[list[str]], many: bool = False, **kwargs
    ):
        """
        Fetch models matching ``kwargs`` with the given relationships loaded.

        :param related: Relationship paths (dot-separated for nesting) to
            load with ``selectinload``.
        :param many: When False return the first match, otherwise all.
        :raises ValueError: If ``related`` is empty.
        """
        if not related:
            raise ValueError("No related fields provided to get related")

        stmt = self.model.where(**kwargs)

        for key in related:
            stmt = apply_nested_loader_options(stmt, self.model, key)

        async with self.async_session() as session:
            results = await session.execute(stmt)
            if not many:
                found = results.scalars().first()
            else:
                found = results.scalars().all()
        return found

    async def stream_by_uids(self, uids: List[Any]) -> AsyncIterator[M]:
        """
        Stream the rows whose ``uid`` is in ``uids``, ordered by ``uid``.

        Yields whatever ``session.stream`` produces for each row.
        """
        # Filter on the uid column, not on the model class itself.
        stmt = select(self.model).where(self.model.uid.in_(uids))  # type: ignore

        async with self.async_session() as session:
            stream = await session.stream(stmt.order_by(self.model.uid))
            async for row in stream:
                yield row

    async def stream_all(self) -> AsyncIterator[Any]:
        """
        Stream every row of the model, ordered by ``uid``.

        Yields whatever ``session.stream`` produces for each row.
        """
        stmt = select(self.model)
        async with self.async_session() as session:
            stream = await session.stream(stmt.order_by(self.model.uid))
            async for row in stream:
                yield row

    async def full_text_search(self, search_string, field):
        """
        Full-text Search with PostgreSQL.

        :param search_string: The text to match.
        :param field: Name of the model attribute to search in.
        :return: All matching models.
        """
        stmt = select(self.model).filter(
            func.to_tsvector("english", getattr(self.model, field)).match(
                search_string, postgresql_regconfig="english"
            )
        )

        async with self.async_session() as session:
            results = await session.execute(stmt)
            search = results.scalars().all()
        return search

    async def delete(self, uid: str) -> None:
        """
        Delete the model with the given ``uid``.

        :raises ValueError: If ``uid`` is falsy.
        """
        if not uid:
            raise ValueError("No uid provided to delete")
        obj = await self.get(uid=uid)
        async with self.async_session() as session:
            await session.delete(obj)
            await session.flush()
            await session.commit()

    async def count_where(self, filters: dict) -> int:
        """
        Count the rows matching ``filters``.

        :param filters: Filters handed to ``smart_query``.
        :return: int
        """
        filter_stmt = self.model.smart_query(filters=filters)
        count_stmt = select(func.count(filter_stmt.c.uid)).select_from(filter_stmt)
        async with self.async_session() as session:
            res = await session.execute(count_stmt)
        return res.scalars().one()

    async def search(self, **kwargs) -> list[M]:
        """
        Case-insensitive "contains" search over one or more fields.

        Each ``field=value`` pair is run as its own ``field__ilike="%value%"``
        query and the results are merged without duplicates.

        :raises ValueError: If no keyword arguments are given.
        """
        if not kwargs:
            raise ValueError("No search arguments provided")

        combined = set()
        for k, v in kwargs.items():
            # Filter with the ilike operator rather than plain equality.
            arg = {f"{k}__ilike": f"%{v}%"}
            query = await self.get_all(**arg)
            combined.update(query)
        return list(combined)

    async def filter(
        self,
        filters: list[dict],
        sort_attrs: list[str] | None = None,
        limit: int | None = None,
        either: bool = False,
    ) -> list[M]:
        """
        Return models matching ``filters`` via ``smart_query``.

        :param filters: The filters to apply.
        :param sort_attrs: Optional sort attributes.
        :param limit: Optional maximum number of rows.
        :param either: When True the filters are wrapped in an OR.
        """
        if either:
            filters = {sa_or_: filters}

        stmt = self.model.smart_query(filters, sort_attrs)

        if limit:
            stmt = stmt.limit(limit)

        async with self.async_session() as session:
            results = await session.execute(stmt)
            found = results.scalars().all()
        return found

    async def paginate(
        self,
        page_size: int | None,
        after_cursor: str | None,
        before_cursor: str | None,
        filters: dict | list[dict] | None,
        sort_by: list[str] | None,
        **kwargs,
    ) -> PageCursor:
        """
        Return one cursor-based page of models.

        :param page_size: Maximum number of items, or None/0 for no limit.
        :param after_cursor: Only return items with ``uid`` greater than this.
        :param before_cursor: Only return items with ``uid`` less than this;
            takes precedence when both cursors are given.
        :param filters: Filters for ``smart_query`` (dict or list of dicts).
        :param sort_by: Optional sort attributes.
        :param kwargs: ``get_related`` may hold relationship paths to load.
        :return: A ``PageCursor`` with total count, edges, items and page info.
        """
        if not filters:
            filters = {}

        # get total count without paging filters from cursors
        total_count: int = await self.count_where(filters=filters)
        total_count = total_count if total_count else 0

        cursor_limit = {}
        if after_cursor:
            cursor_limit = {"uid__gt": self.decode_cursor(after_cursor)}

        if before_cursor:
            cursor_limit = {"uid__lt": self.decode_cursor(before_cursor)}

        # add paging filters
        _filters = None
        if isinstance(filters, dict):
            _filters = [{sa_or_: cursor_limit}, filters] if cursor_limit else filters
        elif isinstance(filters, list):
            # Copy so the caller's list is not extended with the cursor filter.
            _filters = list(filters)
            if cursor_limit:
                _filters.append({sa_or_: cursor_limit})

        stmt = self.model.smart_query(filters=_filters, sort_attrs=sort_by)

        related = kwargs.get("get_related")
        if related:
            for key in related:
                stmt = apply_nested_loader_options(stmt, self.model, key)

        if page_size:
            stmt = stmt.limit(page_size)

        async with self.async_session() as session:
            res = await session.execute(stmt)
            items = res.scalars().all()

        # A full page is taken as a hint that more items follow.
        has_additional = len(items) == page_size if page_size else True

        page_info = {
            "start_cursor": self.encode_cursor(items[0].uid) if items else None,
            "end_cursor": self.encode_cursor(items[-1].uid) if items else None,
        }
        if page_size is not None:
            page_info["has_next_page"] = has_additional
            page_info["has_previous_page"] = bool(after_cursor)

        return PageCursor(
            **{
                "total_count": total_count,
                "edges": self.build_edges(items=items),
                "items": items,
                "page_info": self.build_page_info(**page_info),
            }
        )

    def build_edges(self, items: List[Any]) -> List[EdgeNode]:
        """Wrap each item in an ``EdgeNode``; empty input gives ``[]``."""
        if not items:
            return []
        return [self.build_node(item) for item in items]

    def build_node(self, item: Any) -> EdgeNode:
        """Build the ``EdgeNode`` for one item, using its ``uid`` as cursor."""
        return EdgeNode(**{"cursor": self.encode_cursor(item.uid), "node": item})

    @staticmethod
    def build_page_info(
        start_cursor: str | None = None,
        end_cursor: str | None = None,
        has_next_page: bool = False,
        has_previous_page: bool = False,
    ) -> PageInfo:
        """Build a ``PageInfo`` from the given cursor and paging flags."""
        return PageInfo(
            **{
                "start_cursor": start_cursor,
                "end_cursor": end_cursor,
                "has_next_page": has_next_page,
                "has_previous_page": has_previous_page,
            }
        )

    @staticmethod
    def decode_cursor(cursor):
        """Return the cursor unchanged (cursors are plain identifiers)."""
        return cursor

    @staticmethod
    def encode_cursor(identifier: Any):
        """Return the identifier unchanged (cursors are plain identifiers)."""
        return identifier