2021-10-01 15:37:34 +08:00
|
|
|
from typing import Dict, TypeVar, AsyncIterator, List, Any, Optional
|
2021-09-26 16:38:08 +08:00
|
|
|
from base64 import b64decode, b64encode
|
2021-04-04 15:57:22 +08:00
|
|
|
import logging
|
2021-01-06 19:52:14 +08:00
|
|
|
from pydantic import BaseModel as PydanticBaseModel
|
2021-09-20 17:27:40 +08:00
|
|
|
from sqlalchemy.future import select
|
2021-01-06 19:52:14 +08:00
|
|
|
from sqlalchemy import Column, Integer
|
2021-02-22 14:45:59 +08:00
|
|
|
from sqlalchemy.sql import func
|
2021-10-01 15:37:34 +08:00
|
|
|
from sqlalchemy.orm import selectinload, as_declarative, declared_attr
|
2021-09-26 16:38:08 +08:00
|
|
|
from sqlalchemy import or_ as sa_or_
|
|
|
|
from felicity.database.async_mixins import AllFeaturesMixin, TimestampsMixin, smart_query
|
|
|
|
from felicity.database.paginator.cursor import PageCursor, EdgeNode, PageInfo
|
2021-01-06 19:52:14 +08:00
|
|
|
|
2021-09-26 16:38:08 +08:00
|
|
|
from felicity.database.session import AsyncSessionScoped
|
|
|
|
from felicity.utils import has_value_or_is_truthy
|
2021-01-06 19:52:14 +08:00
|
|
|
|
|
|
|
# Type variable for the Pydantic "in DB" schemas accepted by DBModel._import().
InDBSchemaType = TypeVar("InDBSchemaType", bound=PydanticBaseModel)

# NOTE(review): basicConfig configures the *root* logger at import time, which
# affects the whole process; consider moving it to the application entry point.
logging.basicConfig(level=logging.INFO)

# Module-level logger shared by the model helpers in this file.
logger: logging.Logger = logging.getLogger(__name__)
|
2021-01-06 19:52:14 +08:00
|
|
|
|
2021-09-20 17:27:40 +08:00
|
|
|
|
2021-01-06 19:52:14 +08:00
|
|
|
# noinspection PyPep8Naming
class classproperty:
    """Read-only descriptor that behaves like ``@property`` stacked on ``@classmethod``.

    The wrapped getter always receives the owning class, whether the attribute
    is looked up on the class itself or on one of its instances.

    Adapted from http://stackoverflow.com/a/13624858
    """

    def __init__(self, fget):
        # Getter called as ``fget(owner_cls)`` on every attribute lookup.
        self.fget = fget

    def __get__(self, owner_self, owner_cls):
        # ``owner_self`` (the instance, or None for class access) is ignored on purpose.
        getter = self.fget
        value = getter(owner_cls)
        return value
|
|
|
|
|
|
|
|
|
|
|
|
# Enhanced Base Model Class with some django-like super powers
@as_declarative()
class DBModel(AllFeaturesMixin, TimestampsMixin):
    """Declarative base for the ORM models.

    Provides an autoincrement integer primary key ``uid``, a table name derived
    from the class name, async query/persistence helpers that run on
    ``cls.session`` / ``self.session``, and Relay-style cursor pagination helpers.
    """

    __name__: str
    __abstract__ = True

    uid = Column(Integer, primary_key=True, index=True, nullable=False, autoincrement=True)
    # uid = Column(UUID(), default=uuid.uuid4, primary_key=True, unique=True, nullable=False)

    # Generate __tablename__ automatically
    @declared_attr
    def __tablename__(cls) -> str:
        """Use the lower-cased class name as the table name."""
        return cls.__name__.lower()

    def simple_marshal(self, exclude=None):
        """Convert the instance to a plain dict built from ``self.__dict__``.

        :param exclude: optional iterable of attribute names to leave out. The
            caller's collection is never modified. SQLAlchemy's internal
            ``_sa_instance_state`` entry is always excluded.
        :return: dict mapping attribute name -> value, in ``__dict__`` order.
        """
        # Copy into a fresh set so the caller's list is not mutated between calls.
        skipped = set(exclude) if exclude else set()
        skipped.add("_sa_instance_state")
        return {key: value for key, value in self.__dict__.items() if key not in skipped}

    @classmethod
    async def all_by_page(cls, page: int = 1, limit: int = 20, **kwargs) -> List[Any]:
        """Return one page of rows matching ``kwargs``.

        :param page: 1-based page number. Values below 1 are treated as page 1,
            so the generated OFFSET is never negative.
        :param limit: maximum number of rows in the page.
        :param kwargs: filters forwarded to ``cls.where``.
        :return: list of model instances.
        """
        offset = max(page - 1, 0) * limit
        stmt = cls.where(**kwargs).limit(limit).offset(offset)
        results = await cls.session.execute(stmt)
        return results.scalars().all()

    @classmethod
    async def get(cls, **kwargs):
        """Return the first row matching ``kwargs``, or ``None`` if nothing matches.

        Example:
            await User.get(uid=5)
        """
        stmt = cls.where(**kwargs)
        results = await cls.session.execute(stmt)
        return results.scalars().first()

    @classmethod
    async def create(cls, **kwargs):
        """Create and save a new instance, then return a freshly fetched copy.

        The instance is re-loaded through ``get(uid=...)`` so that mutations work
        with a fresh object and async IO issues are avoided (presumably expired
        attributes being lazy-loaded after commit; confirm).
        """
        instance = cls().fill(**kwargs)
        created = await instance.save()
        if created:
            created = await cls.get(uid=created.uid)
        return created

    async def update(self, **kwargs):
        """Apply ``kwargs`` with ``fill``, save, and return a freshly fetched copy.

        The re-fetch mirrors ``create`` so mutations receive a freshly loaded
        instance instead of the one that was just committed.
        """
        filled = self.fill(**kwargs)
        updated = await filled.save()
        if updated:
            updated = await self.get(uid=updated.uid)
        return updated

    @classmethod
    async def get_related(cls, related: Optional[list] = None, **kwargs):
        """Return the first row matching ``kwargs``, eager-loading ``related``.

        :param related: relationship(s) to load with ``selectinload``. Either a
            list/tuple or a single item. Items are relationship attributes, or
            names if the installed SQLAlchemy accepts string paths.
        :param kwargs: filters forwarded to ``cls.where``.
        :return: the first matching instance, or ``None``.
        """
        stmt = cls.where(**kwargs)
        if related:
            paths = related if isinstance(related, (list, tuple)) else [related]
            # Select statements are generative: options() returns a NEW statement,
            # so the result must be re-assigned or the eager loads are lost.
            stmt = stmt.options(*[selectinload(path) for path in paths])
        logger.info(stmt)
        results = await cls.session.execute(stmt)
        return results.scalars().first()

    @classmethod
    def _import(cls, schema_in: InDBSchemaType) -> Dict[str, Any]:
        """Convert a Pydantic schema to a dict (dicts are returned unchanged).

        Only fields explicitly set on the schema are kept (``exclude_unset=True``).
        """
        if isinstance(schema_in, dict):
            return schema_in
        return schema_in.dict(exclude_unset=True)

    async def save(self):
        """Add this instance to the session, flush and commit.

        On any error the session is rolled back and the exception is re-raised.

        :return: ``self``.
        """
        try:
            self.session.add(self)
            await self.session.flush()
            await self.session.commit()
        except Exception as e:
            logger.error(f"Rolling Back -> Session Save Error: {e}")
            await self.session.rollback()
            raise
        return self

    @classmethod
    async def get_one(cls, **kwargs):
        """Return the first row matching ``kwargs``, or ``None`` (same as ``get``)."""
        return await cls.get(**kwargs)

    @classmethod
    async def get_all(cls, **kwargs):
        """Return every row matching ``kwargs`` as a list."""
        stmt = cls.where(**kwargs)
        results = await cls.session.execute(stmt)
        return results.scalars().all()

    @classmethod
    async def count_where(cls, filters):
        """Return the number of rows matching ``filters`` (``smart_query`` syntax).

        NOTE(review): this loads every matching row and counts it in Python, so
        memory and transfer grow with the result size. A ``count(*)`` over a
        subquery would be cheaper, but first verify how any joins that
        ``smart_query`` may add for relationship filters affect the count.
        """
        stmt = smart_query(select(cls), filters=filters)
        res = await cls.session.execute(stmt)
        return len(res.scalars().all())

    @classmethod
    async def fulltext_search(cls, search_string, field):
        """Full-text Search with PostgreSQL.

        Matches ``search_string`` against ``to_tsvector('english', <field>)``.

        :param field: name of a column attribute on the model. It is resolved with
            ``getattr``, so an unknown name raises ``AttributeError``.
        :return: list of matching instances.
        """
        stmt = select(cls).filter(
            func.to_tsvector('english', getattr(cls, field)).match(search_string, postgresql_regconfig='english')
        )
        results = await cls.session.execute(stmt)
        return results.scalars().all()

    @classmethod
    async def get_by_uids(cls, uids: List[Any]) -> AsyncIterator[Any]:
        """Stream the rows whose ``uid`` is in ``uids``, ordered by ``uid``.

        Yields each result row produced by ``session.stream`` (rows, not bare
        model instances). An empty ``uids`` yields nothing without querying.
        """
        if not uids:
            return
        stmt = select(cls).where(cls.uid.in_(uids)).order_by(cls.uid)  # type: ignore
        stream = await cls.session.stream(stmt)
        async for row in stream:
            yield row

    @classmethod
    async def stream_all(cls) -> AsyncIterator[Any]:
        """Stream every row of the table ordered by ``uid``."""
        stream = await cls.session.stream(select(cls).order_by(cls.uid))
        async for row in stream:
            yield row

    @staticmethod
    def psql_records_to_dict(self, records, many=False):
        """Convert raw driver record(s) to dict(s).

        NOTE(review): although this is a ``@staticmethod``, it takes a leading
        ``self`` argument that is never used. Callers must pass a placeholder,
        e.g. ``DBModel.psql_records_to_dict(None, records)``. The signature is
        kept as-is for backward compatibility.

        :param records: a single mapping-like record, or an iterable of them.
        :param many: when False and ``records`` is truthy, convert that single record.
        :return: a dict, or a list of dicts.
        """
        # records._row: asyncpg.Record / databases.backends.postgres.Record
        if not many and records:
            return dict(records)
        return [dict(record) for record in records]

    @classmethod
    async def paginate_with_cursors(cls, page_size: Optional[int] = None, after_cursor: Any = None,
                                    before_cursor: Any = None, filters: Any = None,
                                    sort_by: Optional[List[str]] = None) -> PageCursor:
        """Relay-style cursor pagination keyed on ``uid``.

        :param page_size: maximum number of items to return. ``None`` returns all.
        :param after_cursor: encoded ``uid``; only rows with a greater ``uid`` are kept.
        :param before_cursor: encoded ``uid``; only rows with a smaller ``uid`` are
            kept. It replaces ``after_cursor`` when both are given.
        :param filters: ``smart_query`` filters, as a dict or a list. The caller's
            object is never mutated.
        :param sort_by: sort attributes forwarded to ``smart_query``.
        :return: PageCursor with the total count (ignoring cursors), edges, items
            and page info.
        """
        if not filters:
            filters = {}

        # get total count without paging filters from cursors
        total_count: int = await cls.count_where(filters=filters)
        total_count = total_count or 0

        cursor_limit = {}
        if has_value_or_is_truthy(after_cursor):
            cursor_limit = {'uid__gt': cls.decode_cursor(after_cursor)}
        if has_value_or_is_truthy(before_cursor):
            cursor_limit = {'uid__lt': cls.decode_cursor(before_cursor)}

        # Add the paging filter on a copy so the caller's list is not modified.
        _filters = None
        if isinstance(filters, dict):
            _filters = [{sa_or_: cursor_limit}, filters] if cursor_limit else filters
        elif isinstance(filters, list):
            _filters = [*filters, {sa_or_: cursor_limit}] if cursor_limit else list(filters)

        stmt = cls.smart_query(filters=_filters, sort_attrs=sort_by)
        if page_size is not None and page_size >= 0:
            # Fetch one extra row: its presence tells us whether a next page exists,
            # without loading the entire result set.
            stmt = stmt.limit(page_size + 1)
        qs = (await cls.session.execute(stmt)).scalars().all()
        items = qs[:page_size]

        has_additional = len(qs) > len(items)
        page_info = {
            'start_cursor': cls.encode_cursor(items[0].uid) if items else None,
            'end_cursor': cls.encode_cursor(items[-1].uid) if items else None,
        }
        if page_size is not None:
            page_info['has_next_page'] = has_additional
            page_info['has_previous_page'] = bool(after_cursor)

        return PageCursor(**{
            'total_count': total_count,
            'edges': cls.build_edges(items=items),
            'items': items,
            'page_info': cls.build_page_info(**page_info)
        })

    @classmethod
    def build_edges(cls, items: List[Any]) -> List[EdgeNode]:
        """Wrap each item in an EdgeNode (an empty or None input gives ``[]``)."""
        if not items:
            return []
        return [cls.build_node(item) for item in items]

    @classmethod
    def build_node(cls, item: Any) -> EdgeNode:
        """Build an EdgeNode whose cursor is the encoded ``item.uid``."""
        return EdgeNode(**{
            'cursor': cls.encode_cursor(item.uid),
            'node': item
        })

    @classmethod
    def build_page_info(cls, start_cursor: Optional[str] = None, end_cursor: Optional[str] = None,
                        has_next_page: bool = False, has_previous_page: bool = False) -> PageInfo:
        """Assemble a PageInfo from its four fields."""
        return PageInfo(**{
            'start_cursor': start_cursor,
            'end_cursor': end_cursor,
            'has_next_page': has_next_page,
            'has_previous_page': has_previous_page
        })

    @classmethod
    def decode_cursor(cls, cursor):
        """Decode a base64 cursor back to its identifier.

        Returns an ``int`` when the decoded text is numeric (the ``uid`` case).
        Otherwise it logs a warning and returns the decoded string. Malformed
        input (e.g. bad base64 padding) raises, e.g. ``binascii.Error``.
        """
        decoded = b64decode(cursor.encode('ascii')).decode('utf8')
        try:
            return int(decoded)
        except ValueError as e:
            logger.warning(e)
            return decoded

    @classmethod
    def encode_cursor(cls, identifier: Any):
        """Encode ``str(identifier)`` as an ASCII base64 cursor string."""
        return b64encode(str(identifier).encode('utf8')).decode('ascii')
|
|
|
|
|
2021-02-27 19:24:39 +08:00
|
|
|
|
2021-09-26 16:38:08 +08:00
|
|
|
# Attach the async scoped session to DBModel via the mixin's ``set_session``;
# DBModel's query and persistence helpers read it back as ``cls.session`` /
# ``self.session``.
DBModel.set_session(AsyncSessionScoped())
|